예제 #1
0
        def fanseg_extract_func(face_type, *args, **kwargs):
            """Run FANSeg extraction for ``face_type``, building the model lazily.

            One TernausNet "FANSeg" model is cached per face type in
            ``self.fanseg_by_face_type``: the first call for a face type builds
            it, later calls reuse the cached instance.

            Args:
                face_type: cache key; also passed to ``FaceType.toString``.
                *args, **kwargs: forwarded unchanged to ``fanseg.extract``.

            Returns:
                Whatever ``fanseg.extract(*args, **kwargs)`` returns.
            """
            # Single lookup; reuse the result for the "not built yet" check.
            fanseg = self.fanseg_by_face_type.get(face_type)
            if fanseg is None:
                # Build under the CPU scope when the device config lists no
                # devices, otherwise under the first GPU.
                cpu_only = len(nn.getCurrentDeviceConfig().devices) == 0

                with nn.tf.device('/CPU:0' if cpu_only else '/GPU:0'):
                    # NOTE(review): place_model_on_cpu=True is passed regardless of
                    # cpu_only — confirm this is intended.
                    fanseg = TernausNet("FANSeg", self.fanseg_input_size , FaceType.toString( face_type ), place_model_on_cpu=True )

                self.fanseg_by_face_type[face_type] = fanseg
            return fanseg.extract(*args, **kwargs)
예제 #2
0
    def on_initialize_options(self):
        """Load or interactively ask the batch size, resolution and stage options.

        Fills ``self.options`` with ``resolution`` and ``stage`` and sets
        ``self.stage_max``. For a non-zero stage it also stores
        ``target_stage_iter``. When the user picks a different stage,
        ``start_stage_iter`` records the iteration at which that happened. At
        stage 0 both stage-iteration keys are removed from ``self.options``.
        """
        device_config = nn.getCurrentDeviceConfig()

        # Suggest a smaller batch when the weakest device has < 4 GB
        # (2 is assumed when no device is listed).
        lowest_vram = 2
        if len(device_config.devices) != 0:
            lowest_vram = device_config.devices.get_worst_device().total_mem_gb
        suggest_batch_size = 8 if lowest_vram >= 4 else 4

        ask_override = self.ask_override()

        resolution = default_resolution = self.options[
            'resolution'] = self.load_or_def_option('resolution', 512)

        if self.is_first_run() or ask_override:
            self.ask_batch_size(suggest_batch_size)
            resolution = io.input_int("Resolution",
                                      default_resolution,
                                      add_info="64-1024")

        # get_power_of_two presumably returns an exponent; it is clipped to
        # 6..10 and the resolution is rebuilt as 2**(stage_max + 2), i.e. one of
        # 64..1024.
        self.stage_max = stage_max = np.clip(
            mathlib.get_power_of_two(resolution), 6, 10) - 2
        self.options['resolution'] = resolution = 2**(stage_max + 2)

        default_stage = self.load_or_def_option('stage', 0)
        default_target_stage_iter = self.load_or_def_option(
            'target_stage_iter', self.iter + 100000)

        if self.is_first_run() or ask_override:
            new_stage = np.clip(
                io.input_int("Stage", default_stage,
                             add_info=f"0-{stage_max}"), 0, stage_max)
            if new_stage != default_stage:
                # Stage changed: restart the per-stage iteration window.
                self.options['start_stage_iter'] = self.iter
                default_target_stage_iter = self.iter + 100000
            self.options['stage'] = new_stage
        else:
            self.options['stage'] = default_stage

        if self.options['stage'] == 0:
            # Stage 0 keeps no stage-iteration bookkeeping; pop() with a
            # default makes removal a no-op when the key is absent.
            self.options.pop('start_stage_iter', None)
            self.options.pop('target_stage_iter', None)
        elif self.is_first_run() or ask_override:
            self.options['target_stage_iter'] = io.input_int(
                "Target stage iteration", default_target_stage_iter)
        else:
            self.options['target_stage_iter'] = default_target_stage_iter
예제 #3
0
파일: Model.py 프로젝트: rumax/DeepFaceLab
    def on_initialize_options(self):
        """Ask the common training options on first run or when overriding.

        If this is the first run or ``self.ask_override()`` returned a truthy
        value, prompts for the autobackup interval, the target iteration and
        the batch size (suggested default: 24). Otherwise no prompt is shown.
        """
        device_config = nn.getCurrentDeviceConfig()
        yn_str = {True: 'y', False: 'n'}

        ask_override = self.ask_override()
        # Guard clause: nothing to ask on a resumed run without override.
        if not (self.is_first_run() or ask_override):
            return

        self.ask_autobackup_hour()
        self.ask_target_iter()
        self.ask_batch_size(24)
예제 #4
0
def depth_to_space(x, size):
    """Move ``size`` x ``size`` blocks of channels into the spatial dimensions.

    The layout follows ``nn.data_format``:

    * "NHWC": always emulated with reshape/transpose. This matches the NCHW
      version so the data format can be switched without problems.
    * any other format: ``tf.depth_to_space`` is used unless the current
      device config is CPU-only. In that case the same shuffle is emulated
      with reshape/transpose in NCHW order.

    Args:
        x: 4-D tensor with statically known dims (``x.shape.as_list()`` is
            unpacked into four sizes).
        size: block size. Channels are floor-divided by ``size * size``;
            presumably they are divisible by it.

    Returns:
        Tensor whose two spatial dims are multiplied by ``size`` and whose
        channel dim is divided by ``size * size``.
    """
    if nn.data_format != "NHWC":
        device_cfg = nn.getCurrentDeviceConfig()
        if not device_cfg.cpu_only:
            return tf.depth_to_space(x, size, data_format=nn.data_format)

        # CPU-only NCHW path: emulate the op with reshape + transpose.
        _, in_ch, in_h, in_w = x.shape.as_list()
        out_h, out_w = in_h * size, in_w * size
        out_ch = in_ch // (size * size)

        blocks = tf.reshape(x, (-1, size, size, out_ch, in_h, in_w))
        blocks = tf.transpose(blocks, (0, 3, 4, 1, 5, 2))
        return tf.reshape(blocks, (-1, out_ch, out_h, out_w))

    # NHWC path: match NCHW version in order to switch data_format without problems.
    _, in_h, in_w, in_ch = x.shape.as_list()
    out_h, out_w = in_h * size, in_w * size
    out_ch = in_ch // (size * size)

    blocks = tf.reshape(x, (-1, in_h, in_w, size, size, out_ch))
    blocks = tf.transpose(blocks, (0, 1, 3, 2, 4, 5))
    return tf.reshape(blocks, (-1, out_h, out_w, out_ch))
예제 #5
0
    def on_initialize_options(self):
        """Ask the common training options and the learning-rate-dropout flag.

        On the first run, or when ``self.ask_override()`` is truthy, this
        prompts for:

        * the autobackup interval,
        * the target iteration,
        * the batch size (suggested 24),
        * whether to use learning rate dropout.

        Otherwise ``options['lr_dropout']`` keeps its stored value (default
        ``False``).
        """
        device_config = nn.getCurrentDeviceConfig()
        yn_str = {True: 'y', False: 'n'}

        ask_override = self.ask_override()

        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_target_iter()
            self.ask_batch_size(24)

        # Seed the option with the stored/default value before (optionally) asking.
        stored_lr_dropout = self.load_or_def_option('lr_dropout', False)
        self.options['lr_dropout'] = stored_lr_dropout

        if self.is_first_run() or ask_override:
            lr_dropout_help = "When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations."
            self.options['lr_dropout'] = io.input_bool("Use learning rate dropout",
                                                       stored_lr_dropout,
                                                       help_message=lr_dropout_help)
예제 #6
0
    def on_initialize_options(self):
        """Load stored model options and ask the user for new values.

        Every option is first seeded into ``self.options`` from
        ``self.load_or_def_option`` (stored value or hard-coded default).

        * Asked only on the first run: resolution, face type, AE architecture,
          and AE/encoder/decoder/mask dimensions.
        * Asked on the first run or when ``self.ask_override()`` is truthy:
          the remaining options.
        * ``models_opt_on_gpu`` is only asked when exactly one device is
          listed.

        Side effects:
            Sets ``self.pretrain_just_disabled``. When it is true, calls
            ``self.set_iter(1)``.

        Raises:
            Exception: pretraining is enabled but
                ``self.get_pretraining_data_path()`` returns ``None``.
        """
        device_config = nn.getCurrentDeviceConfig()

        # Suggest a smaller batch when the weakest device has < 4 GB
        # (2 is assumed when no device is listed).
        lowest_vram = 2
        if len(device_config.devices) != 0:
            lowest_vram = device_config.devices.get_worst_device().total_mem_gb

        if lowest_vram >= 4:
            suggest_batch_size = 8
        else:
            suggest_batch_size = 4

        ask_override = self.ask_override()

        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_flip()
            self.ask_batch_size(suggest_batch_size)

        # Seed every option so it is present in self.options even when the
        # corresponding prompt below is skipped.
        default_resolution         = self.options['resolution']         = self.load_or_def_option('resolution', 128)
        default_face_type          = self.options['face_type']          = self.load_or_def_option('face_type', 'f')
        default_models_opt_on_gpu  = self.options['models_opt_on_gpu']  = self.load_or_def_option('models_opt_on_gpu', True)
        default_archi              = self.options['archi']              = self.load_or_def_option('archi', 'dfhd')
        default_ae_dims            = self.options['ae_dims']            = self.load_or_def_option('ae_dims', 256)
        default_e_dims             = self.options['e_dims']             = self.load_or_def_option('e_dims', 64)
        default_d_dims             = self.options['d_dims']             = self.load_or_def_option('d_dims', 64)

        # Fallback mask dims: a third of the decoder dims, rounded up to even.
        default_d_mask_dims        = default_d_dims // 3
        default_d_mask_dims        += default_d_mask_dims % 2
        default_d_mask_dims        = self.options['d_mask_dims']        = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        default_learn_mask         = self.options['learn_mask']         = self.load_or_def_option('learn_mask', True)
        default_lr_dropout         = self.options['lr_dropout']         = self.load_or_def_option('lr_dropout', False)
        default_random_warp        = self.options['random_warp']        = self.load_or_def_option('random_warp', True)
        default_true_face_training = self.options['true_face_training'] = self.load_or_def_option('true_face_training', False)
        default_face_style_power   = self.options['face_style_power']   = self.load_or_def_option('face_style_power', 0.0)
        default_bg_style_power     = self.options['bg_style_power']     = self.load_or_def_option('bg_style_power', 0.0)
        default_ct_mode            = self.options['ct_mode']            = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad           = self.options['clipgrad']           = self.load_or_def_option('clipgrad', False)
        default_pretrain           = self.options['pretrain']           = self.load_or_def_option('pretrain', False)

        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
            # Round down to a multiple of 16, then clamp to the advertised range.
            resolution = np.clip ( (resolution // 16) * 16, 64, 256)
            self.options['resolution'] = resolution
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower()

        # Only offered when exactly one device is listed.
        if (self.is_first_run() or ask_override) and len(device_config.devices) == 1:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")

        if self.is_first_run():
            self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is heavyweight version for the best quality.").lower() #-s version is slower, but has decreased change to collapse.
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )

            # Encoder/decoder/mask dims: clamp to 16..256, then round odd values up to even.
            e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

        if self.is_first_run() or ask_override:
            self.options['learn_mask']  = io.input_bool ("Learn mask", default_learn_mask, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case merger forced to use 'not predicted mask' that is not smooth as predicted.")
            self.options['lr_dropout']  = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.")
            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.")

            # 'true face' training is only offered for the 'df' architectures.
            if 'df' in self.options['archi']:
                self.options['true_face_training'] = io.input_bool ("Enable 'true face' training", default_true_face_training, help_message="The result face will be more like src and will get extra sharpness. Enable it for last 10-20k iterations before conversion.")
            else:
                self.options['true_face_training'] = False

            self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.1 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
            self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn to transfer background around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
            self.options['ct_mode'] = io.input_str ("Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")

        if self.options['pretrain'] and self.get_pretraining_data_path() is None:
            raise Exception("pretraining_data_path is not defined")

        # True when the loaded model had pretraining on and the user just turned it off.
        self.pretrain_just_disabled = bool(default_pretrain) and not self.options['pretrain']

        if self.pretrain_just_disabled:
            # Reset the iteration counter to 1 once pretraining is switched off.
            self.set_iter(1)
예제 #7
0
    def on_initialize(self):
        nn.initialize()
        tf = nn.tf
        nn.set_floatx(tf.float32)
        conv_kernel_initializer = nn.initializers.ca
        
        class Downscale(nn.ModelBase):
            """Conv2D block that halves the spatial size.

            * ``subpixel=True``: the conv keeps stride 1 and emits
              ``out_ch // 4`` channels. ``tf.nn.space_to_depth(x, 2)`` then
              halves H/W and multiplies the channels by 4.
            * ``subpixel=False``: the conv itself uses stride 2 and emits
              ``out_ch`` channels.

            An optional leaky ReLU (slope 0.1) is applied last.
            """
            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *args, **kwargs ):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.subpixel = subpixel
                self.use_activator = use_activator
                # Forward extra positional AND keyword args (e.g. name=...) to
                # ModelBase; the previous `*kwargs` could not forward keywords.
                super().__init__(*args, **kwargs)

            def on_build(self, *args, **kwargs ):
                self.conv1 = nn.Conv2D( self.in_ch,
                                        self.out_ch // (4 if self.subpixel else 1),
                                        kernel_size=self.kernel_size,
                                        strides=1 if self.subpixel else 2,
                                        padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer )

            def forward(self, x):
                x = self.conv1(x)

                if self.subpixel:
                    x = tf.nn.space_to_depth(x, 2)

                if self.use_activator:
                    x = tf.nn.leaky_relu(x, 0.1)
                return x

            def get_out_ch(self):
                """Return the channel count actually produced by ``forward``."""
                if self.subpixel:
                    # conv emits out_ch // 4 channels; space_to_depth(2) multiplies by 4.
                    return (self.out_ch // 4) * 4
                # Stride-2 conv emits out_ch channels directly. Rounding to a
                # multiple of 4 here would mis-size the next chained conv.
                return self.out_ch

        class DownscaleBlock(nn.ModelBase):
            """Chain of ``n_downscales`` Downscale layers.

            Level ``i`` has width ``ch * min(2**i, 8)``. Each layer's input
            width is the previous layer's ``get_out_ch()``.
            """
            def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
                self.downs = []

                prev_ch = in_ch
                for level in range(n_downscales):
                    width = ch * min(2**level, 8)
                    layer = Downscale(prev_ch, width, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel)
                    self.downs.append(layer)
                    prev_ch = layer.get_out_ch()

            def forward(self, inp):
                out = inp
                for layer in self.downs:
                    out = layer(out)
                return out
                
        class Upscale(nn.ModelBase):
            """Conv2D to ``out_ch * 4`` channels, leaky ReLU (0.1), then depth_to_space(2).

            The depth_to_space step doubles H/W and leaves ``out_ch`` channels.
            """
            def on_build(self, in_ch, out_ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                activated = tf.nn.leaky_relu(self.conv1(x), 0.1)
                return tf.nn.depth_to_space(activated, 2)

        class ResidualBlock(nn.ModelBase):
            """Two same-width convs with a skip connection; leaky ReLU slope 0.2."""
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                hidden = tf.nn.leaky_relu(self.conv1(inp), 0.2)
                hidden = self.conv2(hidden)
                # Add the skip connection before the final activation.
                return tf.nn.leaky_relu(inp + hidden, 0.2)

        class UpdownResidualBlock(nn.ModelBase):
            """Upscale -> residual -> downscale, added back onto the input.

            Returns a pair:

            * the block output at the input's resolution,
            * the intermediate upscaled residual features (``upx``).
            """
            def on_build(self, ch, inner_ch, kernel_size=3 ):
                self.up   = Upscale (ch, inner_ch, kernel_size=kernel_size)
                self.res  = ResidualBlock (inner_ch, kernel_size=kernel_size)
                self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)

            def forward(self, inp):
                upx = self.res(self.up(inp))
                restored = self.down(upx)
                return tf.nn.leaky_relu(restored + inp, 0.2), upx

        class Encoder(nn.ModelBase):
            """Image encoder producing a flat feature vector.

            * HD: four parallel 4-level DownscaleBlocks with different kernel
              sizes and dilations. Their flattened outputs are concatenated on
              the last axis.
            * Non-HD: a single stride-2 (``subpixel=False``) DownscaleBlock,
              flattened.
            """
            def on_build(self, in_ch, e_ch, is_hd):
                self.is_hd = is_hd
                if not is_hd:
                    self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
                    return

                self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
                self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
                self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
                self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)

            def forward(self, inp):
                if not self.is_hd:
                    return nn.tf_flatten(self.down1(inp))

                branches = (self.down1, self.down2, self.down3, self.down4)
                return tf.concat([nn.tf_flatten(branch(inp)) for branch in branches], -1)
                
        class Inter(nn.ModelBase):
            """Dense bottleneck between encoder and decoder.

            The pipeline is:

            1. ``dense1``: ``in_ch -> ae_ch``.
            2. ``dense2``: ``ae_ch -> lowest_dense_res**2 * ae_out_ch``.
            3. Reshape to ``(-1, lowest_dense_res, lowest_dense_res, ae_out_ch)``.
            4. Upscale 2x.
            """
            def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs):
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch

                self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
                self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
                self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

            def forward(self, inp):
                x = self.dense1(inp)
                x = self.dense2(x)
                # Use the value stored on the instance, not a same-named variable
                # from the enclosing scope, so the reshape matches dense2's size.
                x = tf.reshape (x, (-1, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch))
                x = self.upscale1(x)
                return x

            def get_out_ch(self):
                return self.ae_out_ch
                        
        class Decoder(nn.ModelBase):
            """Decode a latent feature map into an image and a mask.

            Outputs:

            * a 3-channel sigmoid image from the color path,
            * a 1-channel sigmoid mask from a separate path of three Upscale
              layers.

            With ``is_hd`` the color path interleaves UpdownResidualBlocks
            (``res0``..``res3``) with the upscales. Otherwise it uses plain
            ResidualBlocks (``res0``..``res2``).
            """
            def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
                self.is_hd = is_hd

                self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)

                if is_hd:
                    self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
                    self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
                    self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
                    self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
                else:
                    self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
                    self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
                    self.res2 = ResidualBlock(d_ch*2, kernel_size=3)

                self.out_conv  = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

                self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
                self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def get_weights_ex(self, include_mask):
                """Return the color-path weights, plus the mask-path weights if requested.

                Args:
                    include_mask: when truthy, also append the weights of
                        ``upscalem0..2`` and ``out_convm``.
                """
                # Call internal get_weights in order to initialize inner logic
                self.get_weights()

                weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
                          + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()

                if self.is_hd:
                    # res3 exists only in the HD variant and forward() runs it on
                    # the color path, so it must be part of the returned weights too.
                    weights += self.res3.get_weights()

                if include_mask:
                    weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
                               + self.out_convm.get_weights()
                return weights

            def forward(self, inp):
                z = inp

                if self.is_hd:
                    # Each UpdownResidualBlock also returns its upscaled features
                    # (upx), which are fused with the matching Upscale output.
                    x, upx = self.res0(z)
                    x = self.upscale0(x)
                    x = tf.nn.leaky_relu(x + upx, 0.2)
                    x, upx = self.res1(x)

                    x = self.upscale1(x)
                    x = tf.nn.leaky_relu(x + upx, 0.2)
                    x, upx = self.res2(x)

                    x = self.upscale2(x)
                    x = tf.nn.leaky_relu(x + upx, 0.2)
                    x, upx = self.res3(x)
                else:
                    x = self.upscale0(z)
                    x = self.res0(x)
                    x = self.upscale1(x)
                    x = self.res1(x)
                    x = self.upscale2(x)
                    x = self.res2(x)

                # The mask path starts again from the latent input z.
                m = self.upscalem0(z)
                m = self.upscalem1(m)
                m = self.upscalem2(m)

                return tf.nn.sigmoid(self.out_conv(x)), \
                       tf.nn.sigmoid(self.out_convm(m))

        class CodeDiscriminator(nn.ModelBase):
            """Discriminator over latent codes.

            The network is:

            * ``2 + code_res // 8`` stride-2 convs, each followed by a leaky
              ReLU with slope 0.1,
            * a final 1x1 conv to a single channel.

            Conv ``i`` has width ``ch * min(2**i, 8)``. The first conv uses
            kernel size 4; the others use 3.
            """
            def on_build(self, in_ch, code_res, ch=256):
                n_downscales = 2 + code_res // 8

                self.convs = []
                width_in = in_ch
                for level in range(n_downscales):
                    width_out = ch * min(2**level, 8)
                    ksize = 4 if level == 0 else 3
                    self.convs.append(nn.Conv2D(width_in, width_out, kernel_size=ksize, strides=2, padding='SAME', kernel_initializer=conv_kernel_initializer))
                    width_in = width_out

                self.out_conv = nn.Conv2D(width_in, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                feats = x
                for conv in self.convs:
                    feats = tf.nn.leaky_relu(conv(feats), 0.1)
                return self.out_conv(feats)

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        resolution = self.options['resolution']
        learn_mask = self.options['learn_mask']
        archi = self.options['archi']
        ae_dims = self.options['ae_dims']        
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims'] 
        self.pretrain = self.options['pretrain']
        
        masked_training = True

        models_opt_on_gpu = False if len(devices) != 1 else self.options['models_opt_on_gpu']
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_nc = 3
        output_nc = 3
        bgr_shape = (resolution, resolution, output_nc)
        mask_shape = (resolution, resolution, 1)
        lowest_dense_res = resolution // 16

        self.model_filename_list = []


        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (tf.float32, (None,)+bgr_shape)
            self.warped_dst = tf.placeholder (tf.float32, (None,)+bgr_shape)

            self.target_src = tf.placeholder (tf.float32, (None,)+bgr_shape)
            self.target_dst = tf.placeholder (tf.float32, (None,)+bgr_shape)

            self.target_srcm = tf.placeholder (tf.float32, (None,)+mask_shape)
            self.target_dstm = tf.placeholder (tf.float32, (None,)+mask_shape)
            
            self.target_dst_0 = tf.placeholder (tf.float32, (None,)+bgr_shape)
            self.target_dst_1 = tf.placeholder (tf.float32, (None,)+bgr_shape)
            self.target_dst_2 = tf.placeholder (tf.float32, (None,)+bgr_shape)
            
        # Initializing model classes
        with tf.device (models_opt_device):
            if 'df' in archi:
                self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')                
                encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1]
                
                self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
                inter_out_ch = self.inter.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
                
                self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_src')
                self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_dst')

                self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                              [self.inter,       'inter.npy'      ],
                                              [self.decoder_src, 'decoder_src.npy'],
                                              [self.decoder_dst, 'decoder_dst.npy']  ]

                if self.is_training:
                    if self.options['true_face_training']:
                        self.dis = CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' )
                        self.model_filename_list += [ [self.dis, 'dis.npy'] ]

            elif 'liae' in archi:
                self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
                encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1]
                
                self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
                self.inter_B  = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
                
                inter_AB_out_ch = self.inter_AB.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
                inter_B_out_ch = self.inter_B.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
                inters_out_ch = inter_AB_out_ch+inter_B_out_ch
                
                self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder')
                    
                self.model_filename_list += [ [self.encoder,  'encoder.npy'],
                                              [self.inter_AB, 'inter_AB.npy'],
                                              [self.inter_B , 'inter_B.npy'],
                                              [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                # Initialize optimizers
                lr=5e-5
                lr_dropout = 0.3 if self.options['lr_dropout'] else 1.0
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
                self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
                if 'df' in archi:
                    self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)
                    self.src_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(False)

                elif 'liae' in archi:
                    self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask)

                self.src_dst_opt.initialize_variables (self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
                
                if self.options['true_face_training']:
                    self.D_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_opt')
                    self.D_opt.initialize_variables ( self.dis.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ (self.D_opt, 'D_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            
            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []
            gpu_D_loss_gvs = []
            gpu_var_loss_gvs = []
            
            for gpu_id in range(gpu_count):
                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        gpu_warped_src   = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst   = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src   = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst   = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm  = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_dstm  = self.target_dstm[batch_slice,:,:,:]
                        gpu_target_dst_0 = self.target_dst_0[batch_slice,:,:,:]
                        gpu_target_dst_1 = self.target_dst_1[batch_slice,:,:,:]
                        gpu_target_dst_2 = self.target_dst_2[batch_slice,:,:,:]

                    # process model tensors
                    if 'df' in archi:
                        gpu_src_code     = self.inter(self.encoder(gpu_warped_src))
                        gpu_dst_code     = self.inter(self.encoder(gpu_warped_dst))
                        
                        gpu_dst_0_code   = self.inter(self.encoder(gpu_target_dst_0))
                        gpu_dst_1_code   = self.inter(self.encoder(gpu_target_dst_1))
                        gpu_dst_2_code   = self.inter(self.encoder(gpu_target_dst_2))
                        
                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                        
                        gpu_pred_src_dst_0, _ = self.decoder_src(gpu_dst_0_code)
                        gpu_pred_src_dst_1, _ = self.decoder_src(gpu_dst_1_code)
                        gpu_pred_src_dst_2, _ = self.decoder_src(gpu_dst_2_code)
                        
                    elif 'liae' in archi:
                        gpu_src_code = self.encoder (gpu_warped_src)
                        gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                        gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code],-1)
                        gpu_dst_code = self.encoder (gpu_warped_dst)
                        gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                        gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                        gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1)
                        gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1)

                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                            
                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)
                    
                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
                    
                    gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm,  max(1, resolution // 32) )
                    gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm,  max(1, resolution // 32) )

                    gpu_target_dst_masked      = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_target_srcmasked_opt  = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_src_loss =  tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
                    if learn_mask:
                        gpu_src_loss += tf.reduce_mean ( tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
 
                    face_style_power = self.options['face_style_power'] / 100.0
                    if face_style_power != 0 and not self.pretrain:
                        gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                    bg_style_power = self.options['bg_style_power'] / 100.0
                    if bg_style_power != 0 and not self.pretrain:
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) 
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )

                    gpu_dst_loss  = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) 
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square(  gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
                    if learn_mask:
                        gpu_dst_loss += tf.reduce_mean ( tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]
                    
                    gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss

                    if self.options['true_face_training']:
                        def DLoss(labels,logits):
                            return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])

                        gpu_src_code_d = self.dis( gpu_src_code )
                        gpu_src_code_d_ones = tf.ones_like(gpu_src_code_d)
                        gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                        gpu_dst_code_d = self.dis( gpu_dst_code )
                        gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)
 
                        gpu_src_dst_loss += 0.01*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                        gpu_D_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
                                      DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                        gpu_D_loss_gvs += [ nn.tf_gradients (gpu_D_loss, self.dis.get_weights() ) ]

                    gpu_src_dst_loss_gvs += [ nn.tf_gradients ( gpu_src_dst_loss, self.src_dst_trainable_weights ) ]
                    

                    gpu_var_loss  = nn.tf_style_loss (gpu_pred_src_dst_1, gpu_pred_src_dst_0, gaussian_blur_radius=resolution//16, loss_weight=100)
                    gpu_var_loss += nn.tf_style_loss (gpu_pred_src_dst_1, gpu_pred_src_dst_2, gaussian_blur_radius=resolution//16, loss_weight=100)
                    gpu_var_loss_gvs += [ nn.tf_gradients ( gpu_var_loss, self.src_trainable_weights ) ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                if gpu_count == 1:
                    pred_src_src = gpu_pred_src_src_list[0]
                    pred_dst_dst = gpu_pred_dst_dst_list[0]
                    pred_src_dst = gpu_pred_src_dst_list[0]
                    pred_src_srcm = gpu_pred_src_srcm_list[0]
                    pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                    pred_src_dstm = gpu_pred_src_dstm_list[0]
                    
                    src_loss = gpu_src_losses[0]
                    dst_loss = gpu_dst_losses[0]
                    src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
                    var_loss_gv = gpu_var_loss_gvs[0]
                else:
                    pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                    pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                    pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                    pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                    pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                    pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)
                    
                    src_loss = nn.tf_average_tensor_list(gpu_src_losses)
                    dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
                    src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
                    var_loss_gv = nn.tf_average_gv_list (gpu_var_loss_gvs)
                    
                if self.options['true_face_training']:
                    D_loss_gv = nn.tf_average_gv_list(gpu_D_loss_gvs)
                    
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv )
                
                if self.options['true_face_training']:
                    D_loss_gv_op = self.D_opt.get_update_op (D_loss_gv )
                
                var_loss_gv_op = self.src_dst_opt.get_update_op (var_loss_gv )

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, \
                              warped_dst, target_dst, target_dstm):
                s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
                                            feed_dict={self.warped_src :warped_src,
                                                       self.target_src :target_src,
                                                       self.target_srcm:target_srcm,
                                                       self.warped_dst :warped_dst,
                                                       self.target_dst :target_dst,
                                                       self.target_dstm:target_dstm,
                                                       })
                s = np.mean(s)
                d = np.mean(d)
                return s, d
            self.src_dst_train = src_dst_train

            def var_train(target_dst_0, target_dst_1, target_dst_2):
                _ = nn.tf_sess.run ( [ var_loss_gv_op],
                                            feed_dict={self.target_dst_0 :target_dst_0,
                                                       self.target_dst_1 :target_dst_1,
                                                       self.target_dst_2 :target_dst_2,
                                                       })
                                                       
            self.var_train = var_train
            
            if self.options['true_face_training']:
                def D_train(warped_src, warped_dst):
                    nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
                self.D_train = D_train

            if learn_mask:
                def AE_view(warped_src, warped_dst):
                    return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                             feed_dict={self.warped_src:warped_src,
                                                        self.warped_dst:warped_dst})
            else:
                def AE_view(warped_src, warped_dst):
                    return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_src_dst],
                                             feed_dict={self.warped_src:warped_src,
                                                        self.warped_dst:warped_dst})
            self.AE_view = AE_view
        else:
            # Initializing merge function            
            with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                if 'df' in archi:                
                    gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    
                elif 'liae' in archi:
                    gpu_dst_code = self.encoder (self.warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1)
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1)

                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    
            if learn_mask:
                def AE_merge( warped_dst):
                    return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})
            else:
                def AE_merge( warped_dst):
                    return nn.tf_sess.run ( [gpu_pred_src_dst], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()
            
            if self.pretrain_just_disabled:
                if 'df' in archi:
                    if model == self.inter:
                        do_init = True
                elif 'liae' in archi:
                    if model == self.inter_AB:
                        do_init = True
            
            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
                
            if do_init:
                model.init_weights()

        # initializing sample generators
        
        if self.is_training:
            t = SampleProcessor.Types
            if self.options['face_type'] == 'h':
                face_type = t.FACE_TYPE_HALF
            elif self.options['face_type'] == 'mf':
                face_type = t.FACE_TYPE_MID_FULL
            elif self.options['face_type'] == 'f':
                face_type = t.FACE_TYPE_FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None
            
            t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
            
            cpu_count = multiprocessing.cpu_count()
            
            src_generators_count = cpu_count // 2
            if self.options['ct_mode'] != 'none':
                src_generators_count = int(src_generators_count * 1.5)                
            dst_generators_count = cpu_count - src_generators_count

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL), 'resolution': resolution } ],
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'resolution':resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_HULL), 'resolution': resolution} ],
                        generators_count=dst_generators_count ),
                             
                    SampleGeneratorFaceTemporal(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=False),
                        output_sample_types = [{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution}],
                        generators_count=dst_generators_count )
                             ])
예제 #8
0
    def on_initialize_options(self):
        """Resolve stored/default model options and let the user edit them interactively.

        Architecture options (resolution, face type, AE/encoder/decoder/mask
        dimensions, morph factor) are only prompted on the first run.
        Training options are prompted on the first run or when the user asks
        to override. All results are written into ``self.options``.

        Side effects:
            ``self.gan_model_changed``: True when the GAN patch size or GAN
                dimensions differ from the values that were stored before
                prompting.
            ``self.pretrain_just_disabled``: True when the stored ``pretrain``
                option was True and it is now False.
        """
        device_config = nn.getCurrentDeviceConfig()

        # Suggest a batch size from the VRAM of the weakest device; with no
        # devices listed, assume 2 GB.
        lowest_vram = 2
        if len(device_config.devices) != 0:
            lowest_vram = device_config.devices.get_worst_device().total_mem_gb

        if lowest_vram >= 4:
            suggest_batch_size = 8
        else:
            suggest_batch_size = 4

        # Allowed resolution range. The user's value is also floored to a
        # multiple of 32 before clipping.
        min_res = 64
        max_res = 640

        # Each option is resolved (stored value or default, via
        # load_or_def_option) and immediately written back into self.options.
        # The "default_*" locals keep the pre-prompt value for use as the
        # prompt default and for change detection.
        default_resolution         = self.options['resolution']         = self.load_or_def_option('resolution', 224)
        default_face_type          = self.options['face_type']          = self.load_or_def_option('face_type', 'wf')
        default_models_opt_on_gpu  = self.options['models_opt_on_gpu']  = self.load_or_def_option('models_opt_on_gpu', True)

        default_ae_dims            = self.options['ae_dims']            = self.load_or_def_option('ae_dims', 256)
        default_e_dims             = self.options['e_dims']             = self.load_or_def_option('e_dims', 64)
        # NOTE(review): d_dims / d_mask_dims are placeholders here. Their real
        # defaults are resolved below (the d_mask_dims default is derived from
        # d_dims). These assignments presumably exist so the keys are inserted
        # into self.options at this position; confirm before removing.
        default_d_dims             = self.options['d_dims']             = self.options.get('d_dims', None)
        default_d_mask_dims        = self.options['d_mask_dims']        = self.options.get('d_mask_dims', None)
        # Read directly from self.options, not through load_or_def_option.
        default_morph_factor       = self.options['morph_factor']       = self.options.get('morph_factor', 0.33)
        default_masked_training    = self.options['masked_training']    = self.load_or_def_option('masked_training', True)
        default_eyes_mouth_prio    = self.options['eyes_mouth_prio']    = self.load_or_def_option('eyes_mouth_prio', True)
        default_uniform_yaw        = self.options['uniform_yaw']        = self.load_or_def_option('uniform_yaw', False)

        # Backward compatibility: a stored bool lr_dropout is mapped to the
        # 'y'/'n' string form; any other value is kept as-is.
        lr_dropout = self.load_or_def_option('lr_dropout', 'n')
        lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout)
        default_lr_dropout         = self.options['lr_dropout'] = lr_dropout

        default_random_warp        = self.options['random_warp']        = self.load_or_def_option('random_warp', True)
        default_ct_mode            = self.options['ct_mode']            = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad           = self.options['clipgrad']           = self.load_or_def_option('clipgrad', False)
        default_pretrain           = self.options['pretrain']           = self.load_or_def_option('pretrain', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_src_flip()
            self.ask_random_dst_flip()
            self.ask_batch_size(suggest_batch_size)

        # Architecture-defining options: only changeable on the first run.
        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info=f"{min_res}-{max_res}", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
            resolution = np.clip ( (resolution // 32) * 32, min_res, max_res)
            self.options['resolution'] = resolution
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['wf','head'], help_message="whole face / head").lower()

        default_d_dims             = self.options['d_dims']             = self.load_or_def_option('d_dims', 64)

        # Default mask decoder width: one third of d_dims, rounded up to an even number.
        default_d_mask_dims        = default_d_dims // 3
        default_d_mask_dims        += default_d_mask_dims % 2
        default_d_mask_dims        = self.options['d_mask_dims']        = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        if self.is_first_run():
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )

            # Encoder/decoder/mask widths are clipped and then rounded up to an even number.
            e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

            morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="The smaller the value, the more src-like facial expressions will appear. The larger the value, the less space there is to train a large dst faceset in the neural network. Typical fine value is 0.33"), 0.1, 0.5 )
            self.options['morph_factor'] = morph_factor

        if self.is_first_run() or ask_override:
            if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
                self.options['masked_training']  = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.")

            self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
            self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')

        # GAN defaults: the patch size defaults to 1/8 of the (possibly just chosen) resolution.
        default_gan_power          = self.options['gan_power']          = self.load_or_def_option('gan_power', 0.0)
        default_gan_patch_size     = self.options['gan_patch_size']     = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
        default_gan_dims           = self.options['gan_dims']           = self.load_or_def_option('gan_dims', 16)

        if self.is_first_run() or ask_override:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")

            self.options['lr_dropout']  = io.input_str ("Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")

            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")

            self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 1.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 1.0 )

            # GAN sub-options are only asked when GAN is enabled.
            if self.options['gan_power'] != 0.0:
                gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
                self.options['gan_patch_size'] = gan_patch_size

                gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-64", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 64 )
                self.options['gan_dims'] = gan_dims

            self.options['ct_mode'] = io.input_str ("Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")

            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, uniform_yaw=Y")

        # Flags consumed elsewhere: the GAN parameters changed versus the stored
        # values, and pretraining went from enabled to disabled in this session.
        self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
예제 #9
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(
            device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256
        self.face_type = FaceType.WHOLE_FACE

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

        bgr_shape = nn.get4Dshape(resolution, resolution, 3)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        # Initializing model classes
        self.model = XSegNet(name=f'XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001,
                                                  lr_dropout=0.3,
                                                  name='opt'),
                             data_format=nn.data_format)

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []

            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                        gpu_input_t)
                    gpu_pred_list.append(gpu_pred_t)

                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])
                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [
                        nn.gradients(gpu_loss, self.model.get_weights())
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred = nn.concat(gpu_pred_list, 0)
                loss = tf.reduce_mean(gpu_losses)

                loss_gv_op = self.model.opt.get_update_op(
                    nn.average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            def train(input_np, target_np):
                l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                      feed_dict={
                                          self.model.input_t: input_np,
                                          self.model.target_t: target_np
                                      })
                return l

            self.train = train

            def view(input_np):
                return nn.tf_sess.run([pred],
                                      feed_dict={self.model.input_t: input_np})

            self.view = view

            # initializing sample generators
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_dst_generators_count = cpu_count // 2
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            srcdst_generator = SampleGeneratorFaceXSeg(
                [self.training_data_src_path, self.training_data_dst_path],
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                resolution=resolution,
                face_type=self.face_type,
                generators_count=src_dst_generators_count,
                data_format=nn.data_format)

            src_generator = SampleGeneratorFace(
                self.training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=src_generators_count,
                raise_on_no_data=False)
            dst_generator = SampleGeneratorFace(
                self.training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=dst_generators_count,
                raise_on_no_data=False)

            self.set_training_data_generators(
                [srcdst_generator, src_generator, dst_generator])
예제 #10
0
    def on_initialize_options(self):
        device_config = nn.getCurrentDeviceConfig()

        lowest_vram = 2
        if len(device_config.devices) != 0:
            lowest_vram = device_config.devices.get_worst_device().total_mem_gb

        if lowest_vram >= 4:
            suggest_batch_size = 8
        else:
            suggest_batch_size = 4

        yn_str = {True:'y',False:'n'}
        min_res = 64
        max_res = 640

        default_resolution         = self.options['resolution']         = self.load_or_def_option('resolution', 128)
        default_face_type          = self.options['face_type']          = self.load_or_def_option('face_type', 'f')
        default_models_opt_on_gpu  = self.options['models_opt_on_gpu']  = self.load_or_def_option('models_opt_on_gpu', True)

        archi = self.load_or_def_option('archi', 'df')
        archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp
        default_archi              = self.options['archi'] = archi

        default_ae_dims            = self.options['ae_dims']            = self.load_or_def_option('ae_dims', 256)
        default_e_dims             = self.options['e_dims']             = self.load_or_def_option('e_dims', 64)
        default_d_dims             = self.options['d_dims']             = self.options.get('d_dims', None)
        default_d_mask_dims        = self.options['d_mask_dims']        = self.options.get('d_mask_dims', None)
        default_masked_training    = self.options['masked_training']    = self.load_or_def_option('masked_training', True)
        default_eyes_prio          = self.options['eyes_prio']          = self.load_or_def_option('eyes_prio', False)
        default_uniform_yaw        = self.options['uniform_yaw']        = self.load_or_def_option('uniform_yaw', False)

        lr_dropout = self.load_or_def_option('lr_dropout', 'n')
        lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp
        default_lr_dropout         = self.options['lr_dropout'] = lr_dropout

        default_random_warp        = self.options['random_warp']        = self.load_or_def_option('random_warp', True)
        default_gan_power          = self.options['gan_power']          = self.load_or_def_option('gan_power', 0.0)
        default_true_face_power    = self.options['true_face_power']    = self.load_or_def_option('true_face_power', 0.0)
        default_face_style_power   = self.options['face_style_power']   = self.load_or_def_option('face_style_power', 0.0)
        default_bg_style_power     = self.options['bg_style_power']     = self.load_or_def_option('bg_style_power', 0.0)
        default_ct_mode            = self.options['ct_mode']            = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad           = self.options['clipgrad']           = self.load_or_def_option('clipgrad', False)
        default_pretrain           = self.options['pretrain']           = self.load_or_def_option('pretrain', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_flip()
            self.ask_batch_size(suggest_batch_size)

        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.")
            resolution = np.clip ( (resolution // 16) * 16, min_res, max_res)
            self.options['resolution'] = resolution
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()

            while True:
                archi = io.input_str ("AE architecture", default_archi, help_message=\
"""
'df' keeps more identity-preserved face.
'liae' can fix overly different face shapes.
'-u' increased likeness of the face.
'-d' (experimental) doubling the resolution using the same computation cost.
Examples: df, liae, df-d, df-ud, liae-ud, ...
""").lower()

                archi_split = archi.split('-')

                if len(archi_split) == 2:
                    archi_type, archi_opts = archi_split
                elif len(archi_split) == 1:
                    archi_type, archi_opts = archi_split[0], None
                else:
                    continue

                if archi_type not in ['df', 'liae']:
                    continue

                if archi_opts is not None:
                    if len(archi_opts) == 0:
                        continue
                    if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0:
                        continue

                    if 'd' in archi_opts:
                        self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res)

                break
            self.options['archi'] = archi

        default_d_dims             = self.options['d_dims']             = self.load_or_def_option('d_dims', 64)

        default_d_mask_dims        = default_d_dims // 3
        default_d_mask_dims        += default_d_mask_dims % 2
        default_d_mask_dims        = self.options['d_mask_dims']        = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        if self.is_first_run():
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )

            e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

        if self.is_first_run() or ask_override:
            if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
                self.options['masked_training']  = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.")

            self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ')
            self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')

        if self.is_first_run() or ask_override:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")

            self.options['lr_dropout']  = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")

            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")

            self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Forces the neural network to learn small details of the face. Enable it only when the face is trained enough and don't disable. Typical value is 0.1"), 0.0, 10.0 )

            if 'df' in self.options['archi']:
                self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
            else:
                self.options['true_face_power'] = 0.0

            self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn the color of the predicted face to be the same as dst inside mask. If you want to use this option with 'whole_face' you have to use XSeg trained mask. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
            self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 )

            self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")

            self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")

        if self.options['pretrain'] and self.get_pretraining_data_path() is None:
            raise Exception("pretraining_data_path is not defined")

        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
예제 #11
0
    def on_initialize_options(self):
        """Load the saved model options and interactively (re)configure them.

        Every option is first seeded with its saved/default value so that it is
        defined in ``self.options`` even when no prompt is shown for it.
        Resolution, face type, the network dimensions and the morph factor are
        only asked on the first run; training options (backup, preview history,
        target iteration, flips, batch size, uniform yaw, optimizer placement,
        random warp, GAN, color transfer, gradient clipping) are asked on the
        first run or when the user chose to override saved settings.

        Side effects:
            Writes the chosen values into ``self.options`` and sets
            ``self.gan_model_changed`` to True when the GAN patch size or GAN
            dimensions differ from their previously loaded values.
        """
        # NOTE(review): the result is unused here; the call is kept because
        # getCurrentDeviceConfig may lazily initialize device state — confirm
        # before removing.
        device_config = nn.getCurrentDeviceConfig()

        default_resolution         = self.options['resolution']         = self.load_or_def_option('resolution', 224)
        default_face_type          = self.options['face_type']          = self.load_or_def_option('face_type', 'wf')
        default_models_opt_on_gpu  = self.options['models_opt_on_gpu']  = self.load_or_def_option('models_opt_on_gpu', True)

        default_ae_dims            = self.options['ae_dims']            = self.load_or_def_option('ae_dims', 256)

        # inter_dims falls back to ae_dims when it was never saved.
        inter_dims = self.load_or_def_option('inter_dims', None)
        if inter_dims is None:
            inter_dims = self.options['ae_dims']
        default_inter_dims         = self.options['inter_dims'] = inter_dims

        default_e_dims             = self.options['e_dims']             = self.load_or_def_option('e_dims', 64)
        # d_dims / d_mask_dims are reloaded with real defaults further down;
        # assigning them here presumably only fixes their position in
        # self.options (dicts keep insertion order).
        default_d_dims             = self.options['d_dims']             = self.options.get('d_dims', None)
        default_d_mask_dims        = self.options['d_mask_dims']        = self.options.get('d_mask_dims', None)
        # Loaded like every other persisted option (it is not reloaded later,
        # unlike d_dims / d_mask_dims).
        default_morph_factor       = self.options['morph_factor']       = self.load_or_def_option('morph_factor', 0.5)
        default_uniform_yaw        = self.options['uniform_yaw']        = self.load_or_def_option('uniform_yaw', False)

        default_random_warp        = self.options['random_warp']        = self.load_or_def_option('random_warp', True)
        default_ct_mode            = self.options['ct_mode']            = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad           = self.options['clipgrad']           = self.load_or_def_option('clipgrad', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_src_flip()
            self.ask_random_dst_flip()
            self.ask_batch_size(8)

        # Structural options can only be chosen when the model is first created.
        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
            # Round down to a multiple of 32, then clamp to [64, 640].
            resolution = np.clip ( (resolution // 32) * 32, 64, 640)
            self.options['resolution'] = resolution
            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="full face / whole face / head").lower()

        default_d_dims             = self.options['d_dims']             = self.load_or_def_option('d_dims', 64)

        # Default mask dims: about a third of d_dims, rounded up to an even number.
        default_d_mask_dims        = default_d_dims // 3
        default_d_mask_dims        += default_d_mask_dims % 2
        default_d_mask_dims        = self.options['d_mask_dims']        = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        if self.is_first_run():
            self.options['ae_dims']    = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
            # NOTE(review): the help text says inter_dims should be >= ae_dims,
            # but that is not enforced here.
            self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 )

            # e/d/d_mask dims: clamp to [16, 256], then round odd values up to even.
            e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

            morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 )
            self.options['morph_factor'] = morph_factor

        if self.is_first_run() or ask_override:
            self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')

        # GAN defaults are loaded after the resolution is settled, so the patch
        # size default tracks the chosen resolution (resolution / 8).
        default_gan_power          = self.options['gan_power']          = self.load_or_def_option('gan_power', 0.0)
        default_gan_patch_size     = self.options['gan_patch_size']     = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
        default_gan_dims           = self.options['gan_dims']           = self.load_or_def_option('gan_dims', 16)

        if self.is_first_run() or ask_override:
            self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")

            self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")

            self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 )

            # GAN shape options are only asked when GAN is enabled.
            if self.options['gan_power'] != 0.0:
                gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
                self.options['gan_patch_size'] = gan_patch_size

                gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 )
                self.options['gan_dims'] = gan_dims

            self.options['ct_mode'] = io.input_str ("Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
            self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")

        # True when the GAN patch size or dims differ from the loaded values.
        self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
예제 #12
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(
            device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        conv_kernel_initializer = nn.initializers.ca()

        class Downscale(nn.ModelBase):
            def __init__(self,
                         in_ch,
                         out_ch,
                         kernel_size=3,
                         dilations=1,
                         use_activator=True,
                         *kwargs):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.use_activator = use_activator
                super().__init__(*kwargs)

            def on_build(self, *args, **kwargs):
                self.conv1 = nn.Conv2D(
                    self.in_ch,
                    self.out_ch,
                    kernel_size=self.kernel_size,
                    strides=2,
                    padding='SAME',
                    dilations=self.dilations,
                    kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                if self.use_activator:
                    x = tf.nn.leaky_relu(x, 0.1)
                return x

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(
                    in_ch,
                    out_ch * 4,
                    kernel_size=kernel_size,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                x = tf.nn.leaky_relu(x, 0.1)
                x = nn.depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, mod=1, kernel_size=3):
                self.conv1 = nn.Conv2D(
                    ch,
                    ch * mod,
                    kernel_size=kernel_size,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)
                self.conv2 = nn.Conv2D(
                    ch * mod,
                    ch,
                    kernel_size=kernel_size,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.1)
                x = self.conv2(x)
                x = inp + x
                x = tf.nn.leaky_relu(x, 0.1)
                return x

        class Encoder(nn.ModelBase):
            def on_build(self, in_ch, e_ch):
                self.conv1 = Downscale(in_ch, e_ch)
                self.conv2 = Downscale(e_ch, e_ch * 2)
                self.conv3 = Downscale(e_ch * 2, e_ch * 4)
                self.conv4 = Downscale(e_ch * 4, e_ch * 8)
                self.conv5 = Downscale(e_ch * 8, e_ch * 16)
                self.conv6 = Downscale(e_ch * 16, e_ch * 32)
                self.conv7 = Downscale(e_ch * 32, e_ch * 64)

                self.res1 = ResidualBlock(e_ch)
                self.res2 = ResidualBlock(e_ch * 2)
                self.res3 = ResidualBlock(e_ch * 4)
                self.res4 = ResidualBlock(e_ch * 8)
                self.res5 = ResidualBlock(e_ch * 16)
                self.res6 = ResidualBlock(e_ch * 32)
                self.res7 = ResidualBlock(e_ch * 64)

            def forward(self, inp):
                x = self.conv1(inp)
                x = self.res1(x)
                x = self.conv2(x)
                x = self.res2(x)
                x = self.conv3(x)
                x = self.res3(x)
                x = self.conv4(x)
                x = self.res4(x)
                x = self.conv5(x)
                x = self.res5(x)
                x = self.conv6(x)
                x = self.res6(x)
                x = self.conv7(x)
                x = self.res7(x)
                return x

        class Inter(nn.ModelBase):
            def __init__(self, in_ch, ae_ch, **kwargs):
                self.in_ch, self.ae_ch = in_ch, ae_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, ae_ch = self.in_ch, self.ae_ch

                self.dense_conv1 = nn.Conv2D(
                    in_ch,
                    64,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)
                self.dense_conv2 = nn.Conv2D(
                    64,
                    in_ch,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

                self.conv7 = Upscale(in_ch, in_ch // 2)
                self.conv6 = Upscale(in_ch // 2, in_ch // 4)

            def forward(self, inp):
                x = inp
                x = self.dense_conv1(x)
                x = self.dense_conv2(x)
                x = self.conv7(x)
                x = self.conv6(x)

                return x

        class Decoder(nn.ModelBase):
            def on_build(self, in_ch):
                self.upscale6 = Upscale(in_ch, in_ch // 2)
                self.upscale5 = Upscale(in_ch // 2, in_ch // 4)
                self.upscale4 = Upscale(in_ch // 4, in_ch // 8)
                self.upscale3 = Upscale(in_ch // 8, in_ch // 16)
                self.upscale2 = Upscale(in_ch // 16, in_ch // 32)
                #self.upscale1 = Upscale(in_ch//32, in_ch//64)
                self.out_conv = nn.Conv2D(
                    in_ch // 32,
                    3,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

                self.res61 = ResidualBlock(in_ch // 2, mod=8)
                self.res62 = ResidualBlock(in_ch // 2, mod=8)
                self.res63 = ResidualBlock(in_ch // 2, mod=8)
                self.res51 = ResidualBlock(in_ch // 4, mod=8)
                self.res52 = ResidualBlock(in_ch // 4, mod=8)
                self.res53 = ResidualBlock(in_ch // 4, mod=8)
                self.res41 = ResidualBlock(in_ch // 8, mod=8)
                self.res42 = ResidualBlock(in_ch // 8, mod=8)
                self.res43 = ResidualBlock(in_ch // 8, mod=8)
                self.res31 = ResidualBlock(in_ch // 16, mod=8)
                self.res32 = ResidualBlock(in_ch // 16, mod=8)
                self.res33 = ResidualBlock(in_ch // 16, mod=8)
                self.res21 = ResidualBlock(in_ch // 32, mod=8)
                self.res22 = ResidualBlock(in_ch // 32, mod=8)
                self.res23 = ResidualBlock(in_ch // 32, mod=8)

                m_ch = in_ch // 2
                self.upscalem6 = Upscale(in_ch, m_ch // 2)
                self.upscalem5 = Upscale(m_ch // 2, m_ch // 4)
                self.upscalem4 = Upscale(m_ch // 4, m_ch // 8)
                self.upscalem3 = Upscale(m_ch // 8, m_ch // 16)
                self.upscalem2 = Upscale(m_ch // 16, m_ch // 32)
                #self.upscalem1 = Upscale(m_ch//32, m_ch//64)
                self.out_convm = nn.Conv2D(
                    m_ch // 32,
                    1,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                z = inp
                x = self.upscale6(z)
                x = self.res61(x)
                x = self.res62(x)
                x = self.res63(x)
                x = self.upscale5(x)
                x = self.res51(x)
                x = self.res52(x)
                x = self.res53(x)
                x = self.upscale4(x)
                x = self.res41(x)
                x = self.res42(x)
                x = self.res43(x)
                x = self.upscale3(x)
                x = self.res31(x)
                x = self.res32(x)
                x = self.res33(x)
                x = self.upscale2(x)
                x = self.res21(x)
                x = self.res22(x)
                x = self.res23(x)
                #x = self.upscale1 (x)

                y = self.upscalem6(z)
                y = self.upscalem5(y)
                y = self.upscalem4(y)
                y = self.upscalem3(y)
                y = self.upscalem2(y)
                #y = self.upscalem1 (y)

                return tf.nn.sigmoid(self.out_conv(x)), \
                       tf.nn.sigmoid(self.out_convm(y))

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        resolution = self.resolution = 128
        ae_dims = 128
        e_dims = 16

        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        models_opt_on_gpu = len(devices) >= 1 and all(
            [dev.total_mem_gb >= 4 for dev in devices])
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        output_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        self.model_filename_list = []

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
            self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')

            self.inter = Inter(in_ch=e_dims * 64, ae_ch=ae_dims, name='inter')

            self.decoder_src = Decoder(in_ch=e_dims * 16, name='decoder_src')
            self.decoder_dst = Decoder(in_ch=e_dims * 16, name='decoder_dst')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                self.src_dst_trainable_weights = self.encoder.get_weights(
                ) + self.inter.get_weights() + self.decoder_src.get_weights(
                ) + self.decoder_dst.get_weights()

                # Initialize optimizers
                self.src_dst_opt = nn.RMSprop(lr=5e-5, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(
                    self.src_dst_trainable_weights,
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt,
                                              'src_dst_opt.npy')]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm = self.target_srcm[
                            batch_slice, :, :, :]
                        gpu_target_dstm = self.target_dstm[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                        gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                        gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                        gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(
                        gpu_target_srcm, max(1, resolution // 32))
                    gpu_target_dstm_blur = nn.gaussian_blur(
                        gpu_target_dstm, max(1, resolution // 32))

                    gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_src_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_src_masked_opt,
                                      gpu_pred_src_src_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_src_masked_opt -
                                       gpu_pred_src_src_masked_opt),
                        axis=[1, 2, 3])
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                        axis=[1, 2, 3])

                    gpu_dst_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_dst_masked_opt,
                                      gpu_pred_dst_dst_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dst_masked_opt -
                                       gpu_pred_dst_dst_masked_opt),
                        axis=[1, 2, 3])
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                        axis=[1, 2, 3])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [
                        nn.gradients(gpu_G_loss,
                                     self.src_dst_trainable_weights)
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                    src_dst_loss_gv)

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm,
                              warped_dst, target_dst, target_dstm):
                """Run one src/dst optimizer step on a batch.

                Evaluates the src loss, the dst loss and the ``src_dst_opt``
                update op in a single session call, feeding the six batch
                arrays into their placeholders.

                Returns:
                    ``(mean src loss, mean dst loss)`` as computed by ``np.mean``.
                """
                fetches = [src_loss, dst_loss, src_dst_loss_gv_op]
                feed = {
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                }
                src_value, dst_value, _ = nn.tf_sess.run(fetches, feed_dict=feed)
                return np.mean(src_value), np.mean(dst_value)

            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                """Evaluate the concatenated predictions for a src/dst batch.

                Returns a list, in this order:
                src decoded by decoder_src, dst decoded by decoder_dst,
                decoder_dst's mask for dst, dst decoded by decoder_src, and
                decoder_src's mask for dst.
                """
                fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm,
                           pred_src_dst, pred_src_dstm]
                feed = {self.warped_src: warped_src,
                        self.warped_dst: warped_dst}
                return nn.tf_sess.run(fetches, feed_dict=feed)

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge(warped_dst):
                """Evaluate merge outputs for a batch of dst faces.

                Returns a list: dst decoded by decoder_src, decoder_dst's mask
                for dst, and decoder_src's mask for dst.
                """
                outputs = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                feed = {self.warped_dst: warped_dst}
                return nn.tf_sess.run(outputs, feed_dict=feed)

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(
                self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(
                    self.get_strpath_storage_for_file(filename))

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
            )
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
            )

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            self.set_training_data_generators([
                SampleGeneratorFace(
                    training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[{
                        'types':
                        (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution,
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution,
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type,
                                  t.MODE_FACE_MASK_ALL_HULL),
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }],
                    generators_count=src_generators_count),
                SampleGeneratorFace(
                    training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[{
                        'types':
                        (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type,
                                  t.MODE_FACE_MASK_ALL_HULL),
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }],
                    generators_count=dst_generators_count)
            ])

            self.last_samples = None
예제 #13
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(
            devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        class EncBlock(nn.ModelBase):
            """Encoder block: two convs with leaky ReLU(0.2), then max-pool.

            At level 0 the second conv is 4x4 with VALID padding and no
            max-pool is applied; at other levels it is 3x3 SAME followed by
            ``nn.max_pool``.
            """

            def on_build(self, in_ch, out_ch, level):
                self.zero_level = level == 0
                if self.zero_level:
                    second_kernel, second_padding = 4, 'VALID'
                else:
                    second_kernel, second_padding = 3, 'SAME'
                self.conv1 = nn.Conv2D(in_ch, out_ch, kernel_size=3,
                                       padding='SAME')
                self.conv2 = nn.Conv2D(out_ch, out_ch,
                                       kernel_size=second_kernel,
                                       padding=second_padding)

            def forward(self, x):
                for conv in (self.conv1, self.conv2):
                    x = tf.nn.leaky_relu(conv(x), 0.2)
                return x if self.zero_level else nn.max_pool(x)

        class DecBlock(nn.ModelBase):
            """Decoder block: optional upsample, then two convs with leaky ReLU(0.2).

            At level 0 there is no upsample and the first conv is 4x4 with
            ``padding=3``; at other levels the input is upsampled with
            ``nn.upsample2d`` and the first conv is 3x3 SAME.
            """

            def on_build(self, in_ch, out_ch, level):
                self.zero_level = level == 0
                if self.zero_level:
                    first_kernel, first_padding = 4, 3
                else:
                    first_kernel, first_padding = 3, 'SAME'
                self.conv1 = nn.Conv2D(in_ch, out_ch,
                                       kernel_size=first_kernel,
                                       padding=first_padding)
                self.conv2 = nn.Conv2D(out_ch, out_ch, kernel_size=3,
                                       padding='SAME')

            def forward(self, x):
                if not self.zero_level:
                    x = nn.upsample2d(x)
                for conv in (self.conv1, self.conv2):
                    x = tf.nn.leaky_relu(conv(x), 0.2)
                return x

        class InterBlock(nn.ModelBase):
            # NOTE(review): this looks like an unfinished copy of EncBlock and is
            # not instantiated anywhere in the visible code. As written it cannot
            # work:
            #   * on_build never creates conv1/conv2, so forward() would raise
            #     AttributeError;
            #   * nn.Dense() is called with no size arguments, whereas the other
            #     nn.Dense uses in this file pass input and output sizes.
            # Confirm the intended design before using this class.
            def on_build(self, in_ch, out_ch, level):
                # in_ch and out_ch are currently unused.
                self.zero_level = level == 0
                self.dense1 = nn.Dense()

            def forward(self, x):
                # Copied from EncBlock.forward; references self.conv1/self.conv2,
                # which this class never defines (self.dense1 is never used).
                x = tf.nn.leaky_relu(self.conv1(x), 0.2)
                x = tf.nn.leaky_relu(self.conv2(x), 0.2)

                if not self.zero_level:
                    x = nn.max_pool(x)

                #if self.zero_level:

                return x

        class FromRGB(nn.ModelBase):
            """1x1 conv from 3 input channels to ``out_ch`` features + leaky ReLU(0.2)."""

            def on_build(self, out_ch):
                self.conv1 = nn.Conv2D(3, out_ch, kernel_size=1, padding='SAME')

            def forward(self, x):
                features = self.conv1(x)
                return tf.nn.leaky_relu(features, 0.2)

        class ToRGB(nn.ModelBase):
            """Project features to a 3-channel image and a 1-channel mask.

            Both heads are 1x1 SAME convs followed by a sigmoid;
            ``forward`` returns ``(image, mask)``.
            """

            def on_build(self, in_ch):
                self.conv = nn.Conv2D(in_ch, 3, kernel_size=1, padding='SAME')
                self.convm = nn.Conv2D(in_ch, 1, kernel_size=1, padding='SAME')

            def forward(self, x):
                image = tf.nn.sigmoid(self.conv(x))
                mask = tf.nn.sigmoid(self.convm(x))
                return image, mask

        ed_dims = 16
        ae_res = 4
        level_chs = {
            i - 1: v
            for i, v in enumerate([
                np.clip(ed_dims * (2**i), 0, 512)
                for i in range(self.stage_max + 2)
            ][::-1])
        }
        ae_ch = level_chs[0]

        class Encoder(nn.ModelBase):
            """Multi-level encoder with a dense bottleneck.

            Holds one FromRGB adapter per level ``0..levels`` and one EncBlock
            per level ``1..levels`` (mapping ``level_chs[level]`` to
            ``level_chs[level - 1]`` channels). The output is reshaped to
            ``(ae_res, ae_res, ae_ch)`` via ``nn.reshape_4D``.
            """

            def on_build(self, e_ch, levels):
                # NOTE: e_ch is currently unused; channel widths come from level_chs.
                self.enc_blocks = {}
                self.from_rgbs = {}

                self.dense_norm = nn.DenseNorm()

                for level in range(levels, -1, -1):
                    self.from_rgbs[level] = FromRGB(level_chs[level])
                    if level != 0:
                        self.enc_blocks[level] = EncBlock(
                            level_chs[level], level_chs[level - 1], level)

                # Bottleneck: ae_res*ae_res*ae_ch -> 256 -> ae_res*ae_res*ae_ch.
                self.ae_dense1 = nn.Dense(ae_res * ae_res * ae_ch, 256)
                self.ae_dense2 = nn.Dense(256, ae_res * ae_res * ae_ch)

            def forward(self, stage, inp, prev_inp=None, alpha=None):
                """Encode ``inp`` starting at level ``stage``.

                Args:
                    stage: level whose FromRGB adapter consumes ``inp``.
                    inp: input tensor for level ``stage``.
                    prev_inp: input for level ``stage - 1``; only read when
                        ``stage > 0``.
                    alpha: blend weight; when ``stage > 0`` the features entering
                        level ``stage - 1`` are
                        ``x * alpha + from_rgbs[stage - 1](prev_inp) * (1 - alpha)``.

                Returns:
                    Tensor reshaped to ``(ae_res, ae_res, ae_ch)``.
                """
                x = inp

                for level in range(stage, -1, -1):
                    # Per-level membership check, consistent with Decoder.forward.
                    if level not in self.from_rgbs:
                        continue
                    if level == stage:
                        x = self.from_rgbs[level](x)
                    elif level == stage - 1:
                        x = x * alpha + self.from_rgbs[level](prev_inp) * (
                            1 - alpha)

                    if level != 0:
                        x = self.enc_blocks[level](x)

                x = nn.flatten(x)
                x = self.dense_norm(x)
                x = self.ae_dense1(x)
                x = self.ae_dense2(x)
                x = nn.reshape_4D(x, ae_res, ae_res, ae_ch)

                return x

            def get_stage_weights(self, stage):
                """Return, as one flat list, the weights used by ``forward(stage, ...)``.

                Includes the FromRGB adapters of levels ``stage`` and
                ``stage - 1``, the EncBlocks of levels ``stage..1`` and both
                bottleneck Dense layers.
                """
                # Call internal get_weights in order to initialize inner logic
                self.get_weights()

                groups = []
                for level in range(stage, -1, -1):
                    if level not in self.from_rgbs:
                        continue
                    if level == stage or level == stage - 1:
                        groups.append(self.from_rgbs[level].get_weights())
                    if level != 0:
                        groups.append(self.enc_blocks[level].get_weights())
                groups.append(self.ae_dense1.get_weights())
                groups.append(self.ae_dense2.get_weights())

                # Flatten in a single pass; sum(list_of_lists, start) re-copies
                # the accumulated list on every addition (quadratic).
                return [weight for group in groups for weight in group]

        class Decoder(nn.ModelBase):
            """Multi-level decoder producing an image and a mask.

            Holds one ToRGB head per level in ``levels_range[0]..levels_range[1]``
            (inclusive) and one DecBlock per non-zero level in that range
            (mapping ``level_chs[level - 1]`` to ``level_chs[level]`` channels).
            """

            def on_build(self, levels_range):

                self.dec_blocks = {}
                self.to_rgbs = {}

                for level in range(levels_range[0], levels_range[1] + 1):
                    self.to_rgbs[level] = ToRGB(level_chs[level])
                    if level != 0:
                        self.dec_blocks[level] = DecBlock(
                            level_chs[level - 1], level_chs[level], level)

            def forward(self, stage, inp, alpha=None, inter=None):
                """Decode ``inp`` through levels ``0..stage``.

                Args:
                    stage: last level to run; its ToRGB head produces the output.
                    inp: input features for level 0.
                    alpha: blend weight, only read when ``stage > 0``.
                    inter: currently unused.

                Returns:
                    ``(image, mask)`` from level ``stage``. When ``stage > 0``
                    each is ``level_output * alpha + upsampled_prev * (1 - alpha)``,
                    where ``upsampled_prev`` is the level ``stage - 1`` head
                    output passed through ``nn.upsample2d``. If no level in
                    ``0..stage`` has a ToRGB head, the raw features are
                    returned as a single tensor instead.
                """
                x = inp

                for level in range(stage + 1):
                    if level not in self.to_rgbs:
                        continue

                    if level == stage and stage > 0:
                        # Previous level's head output, upsampled for blending.
                        prev_x, prev_xm = self.to_rgbs[level - 1](x)
                        prev_x = nn.upsample2d(prev_x)
                        prev_xm = nn.upsample2d(prev_xm)

                    if level != 0:
                        x = self.dec_blocks[level](x)

                    if level == stage:
                        x, xm = self.to_rgbs[level](x)
                        if stage > 0:
                            x = x * alpha + prev_x * (1 - alpha)
                            xm = xm * alpha + prev_xm * (1 - alpha)
                        return x, xm
                return x

            def get_stage_weights(self, stage):
                """Return, as one flat list, the weights used by ``forward(stage, ...)``.

                Includes the DecBlocks of levels ``1..stage`` and the ToRGB
                heads of levels ``stage`` and ``stage - 1``.
                """
                # Call internal get_weights in order to initialize inner logic
                self.get_weights()

                groups = []
                for level in range(stage + 1):
                    if level not in self.to_rgbs:
                        continue
                    if level != 0:
                        groups.append(self.dec_blocks[level].get_weights())
                    if level == stage or level == stage - 1:
                        groups.append(self.to_rgbs[level].get_weights())

                # Flatten in a single pass; sum(list_of_lists, start) re-copies
                # the accumulated list on every addition (quadratic).
                return [weight for group in groups for weight in group]

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.stage = stage = self.options['stage']
        self.start_stage_iter = self.options.get('start_stage_iter', 0)
        self.target_stage_iter = self.options.get('target_stage_iter', 0)

        resolution = self.options['resolution']
        stage_resolutions = [2**(i + 2) for i in range(self.stage_max + 1)]
        stage_resolution = stage_resolutions[stage]
        prev_stage = stage - 1 if stage != 0 else stage
        prev_stage_resolution = stage_resolutions[
            stage - 1] if stage != 0 else stage_resolution

        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_nc = 3
        output_nc = 3
        prev_bgr_shape = nn.get4Dshape(prev_stage_resolution,
                                       prev_stage_resolution, output_nc)
        bgr_shape = nn.get4Dshape(stage_resolution, stage_resolution,
                                  output_nc)
        mask_shape = nn.get4Dshape(stage_resolution, stage_resolution, 1)

        self.model_filename_list = []

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.prev_warped_src = tf.placeholder(tf.float32, prev_bgr_shape)
            self.warped_src = tf.placeholder(tf.float32, bgr_shape)
            self.prev_warped_dst = tf.placeholder(tf.float32, prev_bgr_shape)
            self.warped_dst = tf.placeholder(tf.float32, bgr_shape)

            self.target_src = tf.placeholder(tf.float32, bgr_shape)
            self.target_dst = tf.placeholder(tf.float32, bgr_shape)

            self.target_srcm = tf.placeholder(tf.float32, mask_shape)
            self.target_dstm = tf.placeholder(tf.float32, mask_shape)
            self.alpha_t = tf.placeholder(tf.float32, (None, 1, 1, 1))

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = Encoder(e_ch=ed_dims,
                                   levels=self.stage_max,
                                   name='encoder')

            #self.inter = Decoder(d_ch=ed_dims, total_levels=self.stage_max, levels_range=[0,2], name='inter')
            self.decoder_src = Decoder(levels_range=[0, self.stage_max],
                                       name='decoder_src')
            self.decoder_dst = Decoder(levels_range=[0, self.stage_max],
                                       name='decoder_dst')

            self.model_filename_list += [
                [self.encoder, 'encoder.npy'],
                #[self.inter,       'inter.npy'      ],
                [self.decoder_src, 'decoder_src.npy'],
                [self.decoder_dst, 'decoder_dst.npy']
            ]

            if self.is_training:
                self.src_dst_all_weights = self.encoder.get_weights(
                ) + self.decoder_src.get_weights(
                ) + self.decoder_dst.get_weights()
                self.src_dst_trainable_weights = self.encoder.get_stage_weights(stage) \
                               + self.decoder_src.get_stage_weights(stage) \
                               + self.decoder_dst.get_stage_weights(stage)

                # Initialize optimizers
                self.src_dst_opt = nn.RMSprop(lr=2e-4,
                                              lr_dropout=1.0,
                                              name='src_dst_opt')
                self.src_dst_opt.initialize_variables(
                    self.src_dst_all_weights,
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt,
                                              'src_dst_opt.npy')]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        gpu_prev_warped_src = self.prev_warped_src[
                            batch_slice, :, :, :]
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_prev_warped_dst = self.prev_warped_dst[
                            batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm = self.target_srcm[
                            batch_slice, :, :, :]
                        gpu_target_dstm = self.target_dstm[
                            batch_slice, :, :, :]
                        gpu_alpha_t = self.alpha_t[batch_slice, :, :, :]

                    # process model tensors
                    #gpu_src_code     = self.inter(stage, self.encoder(stage, gpu_warped_src, gpu_prev_warped_src, gpu_alpha_t), gpu_alpha_t )
                    #gpu_dst_code     = self.inter(stage, self.encoder(stage, gpu_warped_dst, gpu_prev_warped_dst, gpu_alpha_t), gpu_alpha_t )
                    gpu_src_code = self.encoder(stage, gpu_warped_src,
                                                gpu_prev_warped_src,
                                                gpu_alpha_t)
                    gpu_dst_code = self.encoder(stage, gpu_warped_dst,
                                                gpu_prev_warped_dst,
                                                gpu_alpha_t)

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                        stage, gpu_src_code, gpu_alpha_t)  #, inter=self.inter)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                        stage, gpu_dst_code, gpu_alpha_t)  #, inter=self.inter)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                        stage, gpu_dst_code, gpu_alpha_t)  #, inter=self.inter)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(
                        gpu_target_srcm, max(1, stage_resolution // 32))
                    gpu_target_dstm_blur = nn.gaussian_blur(
                        gpu_target_dstm, max(1, stage_resolution // 32))

                    gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_src_loss = tf.reduce_mean(
                        10 * tf.square(gpu_target_srcmasked_opt -
                                       gpu_pred_src_src_masked_opt),
                        axis=[1, 2, 3])
                    gpu_src_loss += tf.reduce_mean(
                        tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                        axis=[1, 2, 3])
                    if stage_resolution >= 16:
                        gpu_src_loss += tf.reduce_mean(
                            5 *
                            nn.dssim(gpu_target_srcmasked_opt,
                                     gpu_pred_src_src_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 11.6)),
                            axis=[1])
                    if stage_resolution >= 32:
                        gpu_src_loss += tf.reduce_mean(
                            5 *
                            nn.dssim(gpu_target_srcmasked_opt,
                                     gpu_pred_src_src_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 23.2)),
                            axis=[1])

                    gpu_dst_loss = tf.reduce_mean(
                        10 * tf.square(gpu_target_dst_masked_opt -
                                       gpu_pred_dst_dst_masked_opt),
                        axis=[1, 2, 3])
                    gpu_dst_loss += tf.reduce_mean(
                        tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                        axis=[1, 2, 3])
                    if stage_resolution >= 16:
                        gpu_dst_loss += tf.reduce_mean(
                            5 *
                            nn.dssim(gpu_target_dst_masked_opt,
                                     gpu_pred_dst_dst_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 11.6)),
                            axis=[1])
                    if stage_resolution >= 32:
                        gpu_dst_loss += tf.reduce_mean(
                            5 *
                            nn.dssim(gpu_target_dst_masked_opt,
                                     gpu_pred_dst_dst_masked_opt,
                                     max_val=1.0,
                                     filter_size=int(stage_resolution / 23.2)),
                            axis=[1])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [
                        nn.gradients(gpu_src_dst_loss,
                                     self.src_dst_trainable_weights)
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                if gpu_count == 1:
                    pred_src_src = gpu_pred_src_src_list[0]
                    pred_dst_dst = gpu_pred_dst_dst_list[0]
                    pred_src_dst = gpu_pred_src_dst_list[0]
                    pred_src_srcm = gpu_pred_src_srcm_list[0]
                    pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                    pred_src_dstm = gpu_pred_src_dstm_list[0]

                    src_loss = gpu_src_losses[0]
                    dst_loss = gpu_dst_losses[0]
                    src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
                else:
                    pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                    pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                    pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                    pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                    pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                    pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)

                    src_loss = nn.average_tensor_list(gpu_src_losses)
                    dst_loss = nn.average_tensor_list(gpu_dst_losses)
                    src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)

                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                    src_dst_loss_gv)

            # Initializing training and view functions
            def get_alpha(batch_size):
                """Build the per-sample value fed to ``self.alpha_t``.

                In stage 0 the value is always 0. In later stages it is the
                fraction of the iterations between ``self.start_stage_iter``
                and ``self.target_stage_iter`` that have already elapsed,
                clipped to [0, 1].

                Returns an ``nn.floatx`` array of shape ``(batch_size, 1, 1, 1)``
                holding the same value for every sample.

                NOTE(review): divides by zero if target_stage_iter equals
                start_stage_iter in a non-zero stage.
                """
                if self.stage == 0:
                    fade = 0
                else:
                    elapsed = self.iter - self.start_stage_iter
                    span = self.target_stage_iter - self.start_stage_iter
                    fade = np.clip(elapsed / span, 0, 1)
                per_sample = np.array([fade], nn.floatx.as_numpy_dtype)
                return np.repeat(per_sample.reshape((1, 1, 1, 1)), batch_size, 0)

            def src_dst_train(prev_warped_src, warped_src, target_src, target_srcm,
                              prev_warped_dst, warped_dst, target_dst, target_dstm):
                """Run one optimizer step on the combined src/dst loss.

                Feeds all eight input batches plus the alpha value, executes
                ``src_dst_loss_gv_op`` and returns ``(src_loss, dst_loss)``,
                each averaged over the batch with ``np.mean``.
                """
                feeds = {
                    self.prev_warped_src: prev_warped_src,
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.prev_warped_dst: prev_warped_dst,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                    self.alpha_t: get_alpha(prev_warped_src.shape[0]),
                }
                src_value, dst_value, _ = nn.tf_sess.run(
                    [src_loss, dst_loss, src_dst_loss_gv_op], feed_dict=feeds)
                return np.mean(src_value), np.mean(dst_value)

            self.src_dst_train = src_dst_train

            def AE_view(prev_warped_src, warped_src, prev_warped_dst,
                        warped_dst):
                """Evaluate the preview predictions for one batch.

                Only the warped inputs and the alpha value are fed (no target
                placeholders). Returns the fetched values in the order
                pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                pred_src_dstm.
                """
                fetches = [
                    pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                    pred_src_dstm
                ]
                feeds = {
                    self.prev_warped_src: prev_warped_src,
                    self.warped_src: warped_src,
                    self.prev_warped_dst: prev_warped_dst,
                    self.warped_dst: warped_dst,
                    self.alpha_t: get_alpha(prev_warped_src.shape[0]),
                }
                return nn.tf_sess.run(fetches, feed_dict=feeds)

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code, stage=stage)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code,
                                                        stage=stage)

            def AE_merge(warped_dst):
                """Run the merge graph on a batch of dst images.

                Returns, in order: the src-decoder prediction for dst, the
                dst-decoder mask prediction, and the src-decoder mask
                prediction.
                """
                fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                feeds = {self.warped_dst: warped_dst}
                return nn.tf_sess.run(fetches, feed_dict=feeds)

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(
                self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()

            if self.pretrain_just_disabled:
                if model == self.inter:
                    do_init = True

            if not do_init:
                do_init = not model.load_weights(
                    self.get_strpath_storage_for_file(filename))

            if do_init:
                model.init_weights()

        # initializing sample generators

        if self.is_training:
            self.face_type = FaceType.FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
            )
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
            )

            cpu_count = multiprocessing.cpu_count()

            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count - src_generators_count

            self.set_training_data_generators([
                SampleGeneratorFace(
                    training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': prev_stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': prev_stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_MASK,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.G,
                            'face_mask_type':
                            SampleProcessor.FaceMaskType.FULL_FACE,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': stage_resolution
                        },
                    ],
                    generators_count=src_generators_count),
                SampleGeneratorFace(
                    training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': prev_stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': prev_stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': stage_resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_MASK,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.G,
                            'face_mask_type':
                            SampleProcessor.FaceMaskType.FULL_FACE,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution,
                            'nearest_resize_to': stage_resolution
                        },
                    ],
                    generators_count=dst_generators_count)
            ])

            self.last_samples = None
예제 #14
0
    def on_initialize_options(self):
        """Load persisted model options and prompt the user for new values.

        Every option is first seeded from its saved value (or the built-in
        default) via ``load_or_def_option`` and stored in ``self.options``.
        The network sizes (``ae_dims``, ``e_dims``, ``d_dims``,
        ``d_mask_dims``) are only asked on the first run. The training
        options (optimizer/model placement, GAN power, gradient clipping) are
        asked on the first run or whenever the user chooses to override.
        """
        device_config = nn.getCurrentDeviceConfig()

        # Fixed suggestion; this model does not derive it from available VRAM.
        suggest_batch_size = 4

        default_models_opt_on_gpu = self.options[
            'models_opt_on_gpu'] = self.load_or_def_option(
                'models_opt_on_gpu', False)
        default_ae_dims = self.options['ae_dims'] = self.load_or_def_option(
            'ae_dims', 256)
        default_e_dims = self.options['e_dims'] = self.load_or_def_option(
            'e_dims', 64)
        default_d_dims = self.options['d_dims'] = self.load_or_def_option(
            'd_dims', 64)
        default_d_mask_dims = self.options[
            'd_mask_dims'] = self.load_or_def_option('d_mask_dims', 16)
        default_gan_power = self.options[
            'gan_power'] = self.load_or_def_option('gan_power', 0.0)
        default_clipgrad = self.options['clipgrad'] = self.load_or_def_option(
            'clipgrad', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_batch_size(suggest_batch_size)

        # Network dimensions can only be chosen when the model is first created.
        if self.is_first_run():
            self.options['ae_dims'] = np.clip(
                io.input_int(
                    "AutoEncoder dimensions",
                    default_ae_dims,
                    add_info="32-1024",
                    help_message=
                    "All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU."
                ), 32, 1024)

            e_dims = np.clip(
                io.input_int(
                    "Encoder dimensions",
                    default_e_dims,
                    add_info="16-256",
                    help_message=
                    "More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."
                ), 16, 256)
            # Round odd values up to the next even number.
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip(
                io.input_int(
                    "Decoder dimensions",
                    default_d_dims,
                    add_info="16-256",
                    help_message=
                    "More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."
                ), 16, 256)
            # Round odd values up to the next even number.
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip(
                io.input_int(
                    "Decoder mask dimensions",
                    default_d_mask_dims,
                    add_info="16-256",
                    help_message=
                    "Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality."
                ), 16, 256)
            # Round odd values up to the next even number.
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

        if self.is_first_run() or ask_override:
            # Placement is only configurable with exactly one device.
            if len(device_config.devices) == 1:
                self.options['models_opt_on_gpu'] = io.input_bool(
                    "Place models and optimizer on GPU",
                    default_models_opt_on_gpu,
                    help_message=
                    "When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions."
                )

            self.options['gan_power'] = np.clip(
                io.input_number(
                    "GAN power",
                    default_gan_power,
                    add_info="0.0 .. 10.0",
                    help_message=
                    "Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"
                ), 0.0, 10.0)

            self.options['clipgrad'] = io.input_bool(
                "Enable gradient clipping",
                default_clipgrad,
                help_message=
                "Gradient clipping reduces chance of model collapse, sacrificing speed of training."
            )
예제 #15
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(
            device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        conv_kernel_initializer = nn.initializers.ca()

        class Downscale(nn.ModelBase):
            """Convolution layer that halves the spatial resolution.

            With ``subpixel=True`` a stride-1 convolution producing
            ``out_ch // 4`` channels is followed by ``space_to_depth(2)``.
            Otherwise a stride-2 convolution produces ``out_ch`` channels
            directly. An optional leaky ReLU (slope 0.1) is applied last.
            """

            def __init__(self,
                         in_ch,
                         out_ch,
                         kernel_size=5,
                         dilations=1,
                         subpixel=True,
                         use_activator=True,
                         *args,
                         **kwargs):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.subpixel = subpixel
                self.use_activator = use_activator
                # Forward both extra positional and keyword arguments (e.g.
                # ``name=``) to ModelBase. The former ``*kwargs`` signature
                # rejected every keyword argument not named explicitly above.
                super().__init__(*args, **kwargs)

            def on_build(self, *args, **kwargs):
                self.conv1 = nn.Conv2D(
                    self.in_ch,
                    self.out_ch // (4 if self.subpixel else 1),
                    kernel_size=self.kernel_size,
                    strides=1 if self.subpixel else 2,
                    padding='SAME',
                    dilations=self.dilations,
                    kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                if self.subpixel:
                    x = nn.space_to_depth(x, 2)
                if self.use_activator:
                    x = tf.nn.leaky_relu(x, 0.1)
                return x

            def get_out_ch(self):
                """Return the number of channels ``forward`` actually produces."""
                if self.subpixel:
                    # conv1 emits out_ch // 4 channels; space_to_depth(2)
                    # multiplies the channel count by 4.
                    return (self.out_ch // 4) * 4
                # The stride-2 conv emits exactly out_ch channels. Rounding
                # to a multiple of 4 here would make the next layer's in_ch
                # disagree with its real input (e.g. out_ch=18 -> 16).
                return self.out_ch

        class DownscaleBlock(nn.ModelBase):
            """Chain of ``n_downscales`` Downscale layers applied in order.

            Level ``i`` targets ``ch * min(2**i, 8)`` channels: the width
            doubles per level until it is capped at 8x ``ch``. Each layer is
            built with the previous layer's ``get_out_ch()`` as its input
            width.
            """

            def on_build(self,
                         in_ch,
                         ch,
                         n_downscales,
                         kernel_size,
                         dilations=1,
                         subpixel=True):
                self.downs = []
                width = in_ch
                for level in range(n_downscales):
                    layer = Downscale(width,
                                      ch * min(2**level, 8),
                                      kernel_size=kernel_size,
                                      dilations=dilations,
                                      subpixel=subpixel)
                    self.downs.append(layer)
                    width = layer.get_out_ch()

            def forward(self, inp):
                out = inp
                for layer in self.downs:
                    out = layer(out)
                return out

        class Upscale(nn.ModelBase):
            """Double the spatial resolution with a sub-pixel convolution.

            A 'SAME'-padded convolution produces ``4 * out_ch`` channels,
            followed by leaky ReLU (slope 0.1) and ``depth_to_space(2)``,
            which yields ``out_ch`` channels.
            """

            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(in_ch,
                                       out_ch * 4,
                                       kernel_size=kernel_size,
                                       padding='SAME',
                                       kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                activated = tf.nn.leaky_relu(self.conv1(x), 0.1)
                return nn.depth_to_space(activated, 2)

        class ResidualBlock(nn.ModelBase):
            """Two same-width convolutions with a skip connection.

            Computes ``lrelu(inp + conv2(lrelu(conv1(inp))))`` with leaky
            ReLU slope 0.2. Input and output have ``ch`` channels.
            """

            def on_build(self, ch, kernel_size=3):
                def make_conv():
                    return nn.Conv2D(ch,
                                     ch,
                                     kernel_size=kernel_size,
                                     padding='SAME',
                                     kernel_initializer=conv_kernel_initializer)

                self.conv1 = make_conv()
                self.conv2 = make_conv()

            def forward(self, inp):
                hidden = tf.nn.leaky_relu(self.conv1(inp), 0.2)
                return tf.nn.leaky_relu(inp + self.conv2(hidden), 0.2)

        class UpdownResidualBlock(nn.ModelBase):
            """Upscale -> residual -> downscale path added back onto the input.

            ``forward`` returns a pair: the leaky-ReLU (slope 0.2) of
            ``down(res(up(inp))) + inp``, and the intermediate upscaled
            residual activation ``res(up(inp))``.
            """

            def on_build(self, ch, inner_ch, kernel_size=3):
                self.up = Upscale(ch, inner_ch, kernel_size=kernel_size)
                self.res = ResidualBlock(inner_ch, kernel_size=kernel_size)
                # No activation here; the block applies one after the skip-add.
                self.down = Downscale(inner_ch,
                                      ch,
                                      kernel_size=kernel_size,
                                      use_activator=False)

            def forward(self, inp):
                upx = self.res(self.up(inp))
                merged = self.down(upx) + inp
                return tf.nn.leaky_relu(merged, 0.2), upx

        class Encoder(nn.ModelBase):
            """Five stride-2 downscaling layers followed by a flatten."""

            def on_build(self, in_ch, e_ch):
                # subpixel=False: each level is a plain stride-2 5x5 conv.
                self.down1 = DownscaleBlock(in_ch,
                                            e_ch,
                                            n_downscales=5,
                                            kernel_size=5,
                                            dilations=1,
                                            subpixel=False)

            def forward(self, inp):
                return nn.flatten(self.down1(inp))

        class Inter(nn.ModelBase):
            """Dense bottleneck between the encoder and the decoders.

            Flat input -> Dense(ae_ch) -> Dense(res*res*ae_out_ch), reshaped
            to a ``lowest_dense_res`` x ``lowest_dense_res`` map with
            ``ae_out_ch`` channels, then one Upscale to ``ae_out_ch * 2``
            channels at twice the resolution.
            """

            def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch,
                         **kwargs):
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch

                self.dense1 = nn.Dense(in_ch, ae_ch)
                self.dense2 = nn.Dense(
                    ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)
                self.upscale1 = Upscale(ae_out_ch, ae_out_ch * 2)

            def forward(self, inp):
                x = self.dense1(inp)
                x = self.dense2(x)
                # Use this instance's own resolution (the one dense2 was sized
                # with), not the enclosing-scope variable of the same name.
                x = nn.reshape_4D(x, self.lowest_dense_res,
                                  self.lowest_dense_res, self.ae_out_ch)
                x = self.upscale1(x)
                return x

            def get_out_ch(self):
                """Return the channel count of ``forward``'s output."""
                # upscale1 is Upscale(ae_out_ch, ae_out_ch * 2), so the output
                # has twice ae_out_ch channels.
                return self.ae_out_ch * 2

        class Decoder(nn.ModelBase):
            """Two-branch upsampling decoder producing an image and a mask.

            Image branch: four Upscale + ResidualBlock stages
            (``d_ch*8 -> d_ch*4 -> d_ch*2 -> d_ch``) and a 1x1 conv to 3
            channels. Mask branch: four Upscale stages
            (``d_mask_ch*8 -> d_mask_ch*4 -> d_mask_ch*2 -> d_mask_ch``) and a
            1x1 conv to 1 channel. ``forward`` returns both outputs passed
            through a sigmoid.
            """

            def on_build(self, in_ch, d_ch, d_mask_ch):

                self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3)
                self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3)
                self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3)
                self.upscale3 = Upscale(d_ch * 2, d_ch * 1, kernel_size=3)

                self.res0 = ResidualBlock(d_ch * 8, kernel_size=3)
                self.res1 = ResidualBlock(d_ch * 4, kernel_size=3)
                self.res2 = ResidualBlock(d_ch * 2, kernel_size=3)
                self.res3 = ResidualBlock(d_ch * 1, kernel_size=3)
                self.out_conv = nn.Conv2D(
                    d_ch * 1,
                    3,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

                self.upscalem0 = Upscale(in_ch, d_mask_ch * 8, kernel_size=3)
                self.upscalem1 = Upscale(d_mask_ch * 8,
                                         d_mask_ch * 4,
                                         kernel_size=3)
                self.upscalem2 = Upscale(d_mask_ch * 4,
                                         d_mask_ch * 2,
                                         kernel_size=3)
                self.upscalem3 = Upscale(d_mask_ch * 2,
                                         d_mask_ch * 1,
                                         kernel_size=3)
                self.out_convm = nn.Conv2D(
                    d_mask_ch * 1,
                    1,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=conv_kernel_initializer)

            def get_weights_ex(self, include_mask):
                """Return the image-branch weights, plus the mask-branch
                weights when ``include_mask`` is true."""
                # Call internal get_weights in order to initialize inner logic
                self.get_weights()

                weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() + self.upscale3.get_weights()\
                          + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.res3.get_weights() + self.out_conv.get_weights()

                if include_mask:
                    weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() + self.upscalem3.get_weights() \
                               + self.out_convm.get_weights()
                return weights

            def forward(self, inp):
                z = inp

                x = self.upscale0(z)
                x = self.res0(x)
                x = self.upscale1(x)
                x = self.res1(x)
                x = self.upscale2(x)
                x = self.res2(x)
                x = self.upscale3(x)
                x = self.res3(x)

                # The mask branch starts from the same input as the image branch.
                m = self.upscalem0(z)
                m = self.upscalem1(m)
                m = self.upscalem2(m)
                m = self.upscalem3(m)
                return tf.nn.sigmoid(self.out_conv(x)), \
                       tf.nn.sigmoid(self.out_convm(m))

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 448
        self.learn_mask = learn_mask = True
        eyes_prio = True
        ae_dims = self.options['ae_dims']
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        self.pretrain = False
        self.pretrain_just_disabled = False
        if self.pretrain_just_disabled:
            self.set_iter(0)

        self.gan_power = gan_power = self.options[
            'gan_power'] if not self.pretrain else 0.0

        masked_training = True

        models_opt_on_gpu = False if len(devices) == 0 else True if len(
            devices) > 1 else self.options['models_opt_on_gpu']
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        output_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)
        lowest_dense_res = resolution // 32

        self.model_filename_list = []

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_srcm_all = tf.placeholder(nn.floatx, mask_shape)
            self.target_dstm_all = tf.placeholder(nn.floatx, mask_shape)

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels(
                (nn.floatx, bgr_shape))

            self.inter = Inter(in_ch=encoder_out_ch,
                               lowest_dense_res=lowest_dense_res,
                               ae_ch=ae_dims,
                               ae_out_ch=ae_dims,
                               name='inter')
            inter_out_ch = self.inter.compute_output_channels(
                (nn.floatx, (None, encoder_out_ch)))

            self.decoder_src = Decoder(in_ch=inter_out_ch,
                                       d_ch=d_dims,
                                       d_mask_ch=d_mask_dims,
                                       name='decoder_src')
            self.decoder_dst = Decoder(in_ch=inter_out_ch,
                                       d_ch=d_dims,
                                       d_mask_ch=d_mask_dims,
                                       name='decoder_dst')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                if gan_power != 0:
                    self.D_src = nn.PatchDiscriminator(patch_size=resolution //
                                                       16,
                                                       in_ch=output_ch,
                                                       base_ch=256,
                                                       name="D_src")
                    self.D_dst = nn.PatchDiscriminator(patch_size=resolution //
                                                       16,
                                                       in_ch=output_ch,
                                                       base_ch=256,
                                                       name="D_dst")
                    self.model_filename_list += [[self.D_src, 'D_src.npy']]
                    self.model_filename_list += [[self.D_dst, 'D_dst.npy']]

                # Initialize optimizers
                lr = 5e-5
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
                self.src_dst_opt = nn.RMSprop(lr=lr,
                                              clipnorm=clipnorm,
                                              name='src_dst_opt')
                self.model_filename_list += [(self.src_dst_opt,
                                              'src_dst_opt.npy')]

                self.src_dst_all_trainable_weights = self.encoder.get_weights(
                ) + self.inter.get_weights() + self.decoder_src.get_weights(
                ) + self.decoder_dst.get_weights()
                self.src_dst_trainable_weights = self.encoder.get_weights(
                ) + self.inter.get_weights() + self.decoder_src.get_weights_ex(
                    learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)

                self.src_dst_opt.initialize_variables(
                    self.src_dst_all_trainable_weights,
                    vars_on_cpu=optimizer_vars_on_cpu)

                if gan_power != 0:
                    self.D_src_dst_opt = nn.RMSprop(lr=lr,
                                                    clipnorm=clipnorm,
                                                    name='D_src_dst_opt')
                    self.D_src_dst_opt.initialize_variables(
                        self.D_src.get_weights() + self.D_dst.get_weights(),
                        vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [(self.D_src_dst_opt,
                                                  'D_src_dst_opt.npy')]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []
            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm_all = self.target_srcm_all[
                            batch_slice, :, :, :]
                        gpu_target_dstm_all = self.target_dstm_all[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                        gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                        gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                        gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    # unpack masks from one combined mask
                    gpu_target_srcm = tf.clip_by_value(gpu_target_srcm_all, 0,
                                                       1)
                    gpu_target_dstm = tf.clip_by_value(gpu_target_dstm_all, 0,
                                                       1)
                    gpu_target_srcm_eyes = tf.clip_by_value(
                        gpu_target_srcm_all - 1, 0, 1)
                    gpu_target_dstm_eyes = tf.clip_by_value(
                        gpu_target_dstm_all - 1, 0, 1)

                    gpu_target_srcm_blur = nn.gaussian_blur(
                        gpu_target_srcm, max(1, resolution // 32))
                    gpu_target_dstm_blur = nn.gaussian_blur(
                        gpu_target_dstm, max(1, resolution // 32))

                    gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_src_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_src_masked_opt,
                                      gpu_pred_src_src_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_src_masked_opt -
                                       gpu_pred_src_src_masked_opt),
                        axis=[1, 2, 3])

                    if eyes_prio:
                        gpu_src_loss += tf.reduce_mean(
                            300 *
                            tf.abs(gpu_target_src * gpu_target_srcm_eyes -
                                   gpu_pred_src_src * gpu_target_srcm_eyes),
                            axis=[1, 2, 3])

                    if learn_mask:
                        gpu_src_loss += tf.reduce_mean(
                            10 *
                            tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                            axis=[1, 2, 3])

                    gpu_dst_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_dst_masked_opt,
                                      gpu_pred_dst_dst_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dst_masked_opt -
                                       gpu_pred_dst_dst_masked_opt),
                        axis=[1, 2, 3])

                    if eyes_prio:
                        gpu_dst_loss += tf.reduce_mean(
                            300 *
                            tf.abs(gpu_target_dst * gpu_target_dstm_eyes -
                                   gpu_pred_dst_dst * gpu_target_dstm_eyes),
                            axis=[1, 2, 3])

                    if learn_mask:
                        gpu_dst_loss += tf.reduce_mean(
                            10 *
                            tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                            axis=[1, 2, 3])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLoss(labels, logits):
                        """Discriminator loss: sigmoid cross-entropy of ``logits``
                        against ``labels``, averaged over axes 1, 2 and 3.

                        Leaves one value per element of axis 0 (assumes the
                        logits are 4-D — TODO confirm against the discriminator
                        output).
                        """
                        xent = tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=labels, logits=logits)
                        return tf.reduce_mean(xent, axis=[1, 2, 3])

                    if gan_power != 0:
                        gpu_pred_src_src_d = self.D_src(
                            gpu_pred_src_src_masked_opt)
                        gpu_pred_src_src_d_ones = tf.ones_like(
                            gpu_pred_src_src_d)
                        gpu_pred_src_src_d_zeros = tf.zeros_like(
                            gpu_pred_src_src_d)
                        gpu_target_src_d = self.D_src(
                            gpu_target_src_masked_opt)
                        gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d)
                        gpu_pred_dst_dst_d = self.D_dst(
                            gpu_pred_dst_dst_masked_opt)
                        gpu_pred_dst_dst_d_ones = tf.ones_like(
                            gpu_pred_dst_dst_d)
                        gpu_pred_dst_dst_d_zeros = tf.zeros_like(
                            gpu_pred_dst_dst_d)
                        gpu_target_dst_d = self.D_dst(
                            gpu_target_dst_masked_opt)
                        gpu_target_dst_d_ones = tf.ones_like(gpu_target_dst_d)

                        gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones   , gpu_target_src_d) + \
                                              DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
                                             (DLoss(gpu_target_dst_d_ones   , gpu_target_dst_d) + \
                                              DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5

                        gpu_D_src_dst_loss_gvs += [
                            nn.gradients(
                                gpu_D_src_dst_loss,
                                self.D_src.get_weights() +
                                self.D_dst.get_weights())
                        ]

                        gpu_G_loss += gan_power * (
                            DLoss(gpu_pred_src_src_d_ones,
                                  gpu_pred_src_src_d) +
                            DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))

                    gpu_G_loss_gvs += [
                        nn.gradients(gpu_G_loss,
                                     self.src_dst_trainable_weights)
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)
                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                    nn.average_gv_list(gpu_G_loss_gvs))

                if gan_power != 0:
                    src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op(
                        nn.average_gv_list(gpu_D_src_dst_loss_gvs))

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm_all,
                              warped_dst, target_dst, target_dstm_all):
                """Run one update of ``src_dst_opt`` on a full batch.

                The six arrays are fed into the matching placeholders. The
                session fetches the src loss, the dst loss and the optimizer
                update op in a single run.

                Returns:
                    tuple: ``(np.mean(src_loss), np.mean(dst_loss))`` for the batch.
                """
                feed = {
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm_all: target_srcm_all,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm_all: target_dstm_all,
                }
                src_val, dst_val, _ = nn.tf_sess.run(
                    [src_loss, dst_loss, src_dst_loss_gv_op], feed_dict=feed)
                return np.mean(src_val), np.mean(dst_val)

            self.src_dst_train = src_dst_train

            if gan_power != 0:
                def D_src_dst_train(warped_src, target_src, target_srcm_all,
                                    warped_dst, target_dst, target_dstm_all):
                    """Run one update of ``D_src_dst_opt`` (the D_src/D_dst
                    discriminators) on a full batch.

                    The six arrays are fed into their placeholders. Only the
                    discriminator update op is fetched, and nothing is returned.
                    """
                    feed = {
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.target_srcm_all: target_srcm_all,
                        self.warped_dst: warped_dst,
                        self.target_dst: target_dst,
                        self.target_dstm_all: target_dstm_all
                    }
                    nn.tf_sess.run([src_D_src_dst_loss_gv_op], feed_dict=feed)

                self.D_src_dst_train = D_src_dst_train

            if learn_mask:

                def AE_view(warped_src, warped_dst):
                    """Preview helper for the learned-mask configuration.

                    Returns the session results for
                    ``[pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                    pred_src_dstm]``, in that order.
                    """
                    fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm,
                               pred_src_dst, pred_src_dstm]
                    feed = {self.warped_src: warped_src,
                            self.warped_dst: warped_dst}
                    return nn.tf_sess.run(fetches, feed_dict=feed)
            else:

                def AE_view(warped_src, warped_dst):
                    """Preview helper used when masks are not learned.

                    Returns the session results for
                    ``[pred_src_src, pred_dst_dst, pred_src_dst]``.
                    """
                    fetches = [pred_src_src, pred_dst_dst, pred_src_dst]
                    feed = {self.warped_src: warped_src,
                            self.warped_dst: warped_dst}
                    return nn.tf_sess.run(fetches, feed_dict=feed)

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            if learn_mask:

                def AE_merge(warped_dst):
                    """Merge helper for the learned-mask configuration.

                    Feeds only ``warped_dst`` and returns the session results
                    for ``[gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]``.
                    """
                    fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm,
                               gpu_pred_src_dstm]
                    return nn.tf_sess.run(
                        fetches, feed_dict={self.warped_dst: warped_dst})
            else:

                def AE_merge(warped_dst):
                    """Merge helper used when masks are not learned.

                    Returns a one-element list holding the evaluated
                    ``gpu_pred_src_dst`` for the fed ``warped_dst`` batch.
                    """
                    feed = {self.warped_dst: warped_dst}
                    return nn.tf_sess.run([gpu_pred_src_dst], feed_dict=feed)

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(
                self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(
                    self.get_strpath_storage_for_file(filename))

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_HEAD

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
            )
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
            )

            t_img_warped = t.IMG_WARPED_TRANSFORMED

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            self.set_training_data_generators([
                SampleGeneratorFace(
                    training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'types': (t_img_warped, face_type, t.MODE_BGR),
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'types':
                            (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'types': (t.IMG_TRANSFORMED, face_type,
                                      t.MODE_FACE_MASK_ALL_EYES_HULL),
                            'data_format':
                            nn.data_format,
                            'resolution':
                            resolution
                        },
                    ],
                    generators_count=src_generators_count),
                SampleGeneratorFace(
                    training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'types': (t_img_warped, face_type, t.MODE_BGR),
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'types':
                            (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'types': (t.IMG_TRANSFORMED, face_type,
                                      t.MODE_FACE_MASK_ALL_EYES_HULL),
                            'data_format':
                            nn.data_format,
                            'resolution':
                            resolution
                        },
                    ],
                    generators_count=dst_generators_count)
            ])

            if self.pretrain_just_disabled:
                self.update_sample_for_preview(force_new=True)
예제 #16
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        input_ch=3
        resolution  = self.resolution = self.options['resolution']
        e_dims      = self.options['e_dims']
        ae_dims     = self.options['ae_dims']
        inter_dims  = self.inter_dims = self.options['inter_dims']
        inter_res   = self.inter_res = resolution // 32
        d_dims      = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        face_type   = self.face_type = {'f'    : FaceType.FULL,
                                        'wf'   : FaceType.WHOLE_FACE,
                                        'head' : FaceType.HEAD}[ self.options['face_type'] ]
        morph_factor = self.options['morph_factor']
        gan_power    = self.gan_power = self.options['gan_power']
        random_warp  = self.options['random_warp']

        blur_out_mask = self.options['blur_out_mask']

        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None

        use_fp16 = False
        if self.is_exporting:
            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')

        conv_dtype = tf.float16 if use_fp16 else tf.float32

        class Downscale(nn.ModelBase):
            """Stride-2 'SAME' convolution followed by leaky ReLU (alpha 0.1)."""

            def on_build(self, in_ch, out_ch, kernel_size=5):
                # conv_dtype is float16 only when the user opted into a
                # quantized export (use_fp16); otherwise float32.
                self.conv1 = nn.Conv2D(in_ch, out_ch, kernel_size=kernel_size,
                                       strides=2, padding='SAME',
                                       dtype=conv_dtype)

            def forward(self, x):
                y = self.conv1(x)
                return tf.nn.leaky_relu(y, 0.1)

        class Upscale(nn.ModelBase):
            """Sub-pixel upscale block.

            Applies a 'SAME' convolution to ``out_ch * 4`` channels and a leaky
            ReLU (alpha 0.1). ``nn.depth_to_space`` with block size 2 then
            rearranges those channels into 2x2 spatial blocks of ``out_ch``
            channels.
            """

            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(in_ch, out_ch * 4,
                                       kernel_size=kernel_size,
                                       padding='SAME', dtype=conv_dtype)

            def forward(self, x):
                activated = tf.nn.leaky_relu(self.conv1(x), 0.1)
                return nn.depth_to_space(activated, 2)

        class ResidualBlock(nn.ModelBase):
            """Two same-width 'SAME' convolutions with an identity skip.

            Computes ``leaky_relu(inp + conv2(leaky_relu(conv1(inp))), 0.2)``.
            Both activations use alpha 0.2.
            """

            def on_build(self, ch, kernel_size=3):
                self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size,
                                       padding='SAME', dtype=conv_dtype)
                self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size,
                                       padding='SAME', dtype=conv_dtype)

            def forward(self, inp):
                h = tf.nn.leaky_relu(self.conv1(inp), 0.2)
                h = self.conv2(h)
                return tf.nn.leaky_relu(inp + h, 0.2)

        class Encoder(nn.ModelBase):
            """Image encoder.

            Runs five stride-2 Downscale stages, with a ResidualBlock after the
            first and the last. The result is flattened, pixel-normalised over
            the last axis and projected by a dense layer to ``ae_dims`` features.
            """

            def on_build(self):
                self.down1 = Downscale(input_ch, e_dims, kernel_size=5)
                self.res1 = ResidualBlock(e_dims)
                self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5)
                self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5)
                self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5)
                self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5)
                self.res5 = ResidualBlock(e_dims*8)
                # Dense input width assumes each of the 5 downscales halves the
                # side length: (resolution // 32)**2 positions * e_dims*8 channels.
                self.dense1 = nn.Dense((( resolution//(2**5) )**2) * e_dims*8, ae_dims)

            def forward(self, x):
                if use_fp16:
                    x = tf.cast(x, tf.float16)
                # Local tuple only: do not store it on self, or ModelBase would
                # treat it as another sub-layer container.
                for stage in (self.down1, self.res1, self.down2, self.down3,
                              self.down4, self.down5, self.res5):
                    x = stage(x)
                if use_fp16:
                    # dense1 is created without dtype=conv_dtype (presumably
                    # float32), so leave the fp16 path before flattening.
                    x = tf.cast(x, tf.float32)
                flat = nn.pixel_norm(nn.flatten(x), axes=-1)
                return self.dense1(flat)


        class Inter(nn.ModelBase):
            """Dense projection of the ``ae_dims`` code into a 4-D map of
            ``inter_res x inter_res x inter_dims`` (via ``nn.reshape_4D``)."""

            def on_build(self):
                self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims)

            def forward(self, inp):
                flat = self.dense2(inp)
                return nn.reshape_4D(flat, inter_res, inter_res, inter_dims)


        class Decoder(nn.ModelBase):
            """Decode an Inter feature map into a 3-channel image and a mask.

            ``forward`` returns ``(x, m)``. ``x`` is the sigmoid of four
            3-channel heads, concatenated on the channel axis and folded by
            ``nn.depth_to_space(..., 2)``. ``m`` is the sigmoid of a
            1-channel 1x1 convolution at the end of a separate five-upscale
            branch. With ``use_fp16`` the input is cast to float16 and both
            outputs are cast back to float32.
            """

            def on_build(self):
                # Image branch: four upscales, each followed by a residual block.
                self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3)
                self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3)
                self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3)
                self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3)

                self.res0 = ResidualBlock(d_dims*8, kernel_size=3)
                self.res1 = ResidualBlock(d_dims*8, kernel_size=3)
                self.res2 = ResidualBlock(d_dims*4, kernel_size=3)
                self.res3 = ResidualBlock(d_dims*2, kernel_size=3)

                # Mask branch: five upscales and a 1x1 conv to one channel.
                self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3)
                self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3)
                self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
                self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
                self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
                self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)

                # Four 3-channel output heads, later folded by depth_to_space(2).
                self.out_conv  = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
                self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
                self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)

            def forward(self, z):
                if use_fp16:
                    z = tf.cast(z, tf.float16)

                # Image branch. Each stage upscales, then applies its residual block.
                x = z
                for up, res in ((self.upscale0, self.res0),
                                (self.upscale1, self.res1),
                                (self.upscale2, self.res2),
                                (self.upscale3, self.res3)):
                    x = res(up(x))

                heads = tuple(conv(x) for conv in (self.out_conv, self.out_conv1,
                                                   self.out_conv2, self.out_conv3))
                stacked = tf.concat(heads, nn.conv2d_ch_axis)
                x = tf.nn.sigmoid(nn.depth_to_space(stacked, 2))

                # Mask branch, driven from the same (possibly cast) input.
                m = z
                for up in (self.upscalem0, self.upscalem1, self.upscalem2,
                           self.upscalem3, self.upscalem4):
                    m = up(m)
                m = tf.nn.sigmoid(self.out_convm(m))

                if use_fp16:
                    x = tf.cast(x, tf.float32)
                    m = tf.cast(m, tf.float32)
                return x, m

        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')

            self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')

            self.target_srcm    = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
            self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
            self.target_dstm    = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
            self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')

            self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')

        # Initializing model classes
        with tf.device (models_opt_device):
            self.encoder = Encoder(name='encoder')
            self.inter_src = Inter(name='inter_src')
            self.inter_dst = Inter(name='inter_dst')
            self.decoder = Decoder(name='decoder')

            self.model_filename_list += [   [self.encoder,  'encoder.npy'],
                                            [self.inter_src, 'inter_src.npy'],
                                            [self.inter_dst , 'inter_dst.npy'],
                                            [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                # Initialize optimizers
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
                lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0

                self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()

                #if random_warp:
                #    self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights()

                self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
                self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

                if gan_power != 0:
                    self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
                    self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
                    self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ [self.GAN, 'GAN.npy'],
                                                  [self.GAN_opt, 'GAN_opt.npy'] ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gradients = []
            gpu_GAN_loss_gradients = []

            def DLossOnes(logits):
                return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3])

            def DLossZeros(logits):
                """Sigmoid cross-entropy of ``logits`` against a label of 0 everywhere.

                The per-element loss is averaged over axes 1, 2 and 3, leaving one
                value per index of axis 0 (presumably the batch axis; assumes 4-D
                discriminator logits -- TODO confirm).
                """
                target_zeros = tf.zeros_like(logits)
                per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=target_zeros, logits=logits)
                return tf.reduce_mean(per_element, axis=[1, 2, 3])

            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src      = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst      = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src      = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst      = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm     = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_srcm_em  = self.target_srcm_em[batch_slice,:,:,:]
                        gpu_target_dstm     = self.target_dstm[batch_slice,:,:,:]
                        gpu_target_dstm_em  = self.target_dstm_em[batch_slice,:,:,:]

                    # process model tensors
                    gpu_src_code = self.encoder (gpu_warped_src)
                    gpu_dst_code = self.encoder (gpu_warped_dst)

                    gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
                    gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)

                    inter_dims_bin = int(inter_dims*morph_factor)
                    with tf.device(f'/CPU:0'):
                        inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
                                                                                    tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0)

                        inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])

                    gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
                    gpu_dst_code = gpu_dst_inter_dst_code

                    inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
                    gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0],   [-1, inter_dims_slice , inter_res, inter_res]),
                                                   tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_anti = 1-gpu_target_srcm
                    gpu_target_dstm_anti = 1-gpu_target_dstm

                    gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
                    gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)

                    gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
                    gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
                    gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
                    gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur

                    if blur_out_mask:
                        sigma = resolution / 128
                        
                        x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma)
                        y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) 
                        y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)                        
                        gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti
                        
                        x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma)
                        y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) 
                        y = tf.where(tf.equal(y, 0), tf.ones_like(y), y)                        
                        gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti

                    gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
                    gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur

                    gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
                    gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
                    gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur

                    # Structural loss
                    gpu_src_loss =  tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_dst_loss =  tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])

                    # Pixel loss
                    gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3])

                    # Eyes+mouth prio loss
                    gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3])

                    # Mask loss
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]
                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    # dst-dst background weak loss
                    gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
                    gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)


                    if gan_power != 0:
                        gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
                        gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
                        gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked)
                        gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked)

                        gpu_GAN_loss = (DLossOnes (gpu_target_src_d)   + DLossOnes (gpu_target_src_d2) + \
                                        DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
                                        DLossOnes (gpu_target_dst_d)   + DLossOnes (gpu_target_dst_d2) + \
                                        DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
                                        ) * (1.0 / 8)

                        gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]

                        gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
                                       DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
                                      ) * gan_power

                        # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
                        gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                        gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )

                    gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(f'/CPU:0'):
                pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            with tf.device (models_opt_device):
                src_loss = tf.concat(gpu_src_losses, 0)
                dst_loss = tf.concat(gpu_dst_losses, 0)
                train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))

                if gan_power != 0:
                    GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )

            # Initializing training and view functions
            def train(warped_src, target_src, target_srcm, target_srcm_em,
                      warped_dst, target_dst, target_dstm, target_dstm_em):
                """Run one update of the autoencoder weights on a batch.

                The eight src/dst arrays are fed into their matching placeholders.
                ``train_op`` (the ``src_dst_opt`` update) is executed in the same
                session run that evaluates ``src_loss`` and ``dst_loss``.

                Returns:
                    ``(src_loss, dst_loss)`` as evaluated for this batch. These are
                    the per-device loss vectors concatenated along axis 0.
                """
                feed = {
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.target_srcm_em: target_srcm_em,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                    self.target_dstm_em: target_dstm_em,
                }
                src_loss_value, dst_loss_value, _ = nn.tf_sess.run([src_loss, dst_loss, train_op], feed_dict=feed)
                return src_loss_value, dst_loss_value
            self.train = train

            if gan_power != 0:
                def GAN_train(warped_src, target_src, target_srcm, target_srcm_em,
                              warped_dst, target_dst, target_dstm, target_dstm_em):
                    """Run one update of the patch discriminator (``GAN_train_op``) on a batch.

                    The eight src/dst arrays are fed into their matching placeholders.
                    Nothing is returned.
                    """
                    feed = {
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.target_srcm: target_srcm,
                        self.target_srcm_em: target_srcm_em,
                        self.warped_dst: warped_dst,
                        self.target_dst: target_dst,
                        self.target_dstm: target_dstm,
                        self.target_dstm_em: target_dstm_em,
                    }
                    nn.tf_sess.run([GAN_train_op], feed_dict=feed)
                self.GAN_train = GAN_train

            def AE_view(warped_src, warped_dst, morph_value):
                """Evaluate the preview outputs of the training graph for one batch.

                ``morph_value`` is fed as a one-element list into ``morph_value_t``.
                The graph uses it to decide how many of the ``inter_dims`` code
                channels are taken from the inter_src code rather than the
                inter_dst code when building the src->dst code.

                Returns:
                    The session results, in order: pred_src_src, pred_dst_dst,
                    pred_dst_dstm, pred_src_dst, pred_src_dstm.
                """
                fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]
                feed = {
                    self.warped_src: warped_src,
                    self.warped_dst: warped_dst,
                    self.morph_value_t: [morph_value],
                }
                return nn.tf_sess.run(fetches, feed_dict=feed)

            self.AE_view = AE_view
        else:
            #Initializing merge function
            with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.encoder (self.warped_dst)
                gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
                gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)

                inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
                gpu_src_dst_code =  tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0],   [-1, inter_dims_slice , inter_res, inter_res]),
                                                 tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)

            def AE_merge(warped_dst, morph_value):
                """Run the inference (merge) graph on a batch of dst faces.

                ``morph_value`` is fed as a one-element list into ``morph_value_t``.
                ``int(inter_dims * morph_value)`` leading code channels come from
                the inter_src code; the remaining channels come from the
                inter_dst code.

                Returns:
                    The session results, in order: predicted src->dst image,
                    predicted dst->dst mask, predicted src->dst mask.
                """
                fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                feed = {self.warped_dst: warped_dst, self.morph_value_t: [morph_value]}
                return nn.tf_sess.run(fetches, feed_dict=feed)

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()
            if self.is_training and gan_power != 0 and model == self.GAN:
                if self.gan_model_changed:
                    do_init = True
            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
            if do_init:
                model.init_weights()
        ###############

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain

            cpu_count = multiprocessing.cpu_count()
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if ct_mode is not None:
                src_generators_count = int(src_generators_count * 1.5)



            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                         'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                         'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE,  'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                             'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                             'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE,  'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
                        generators_count=dst_generators_count )
                             ])

            self.last_src_samples_loss = []
            self.last_dst_samples_loss = []
예제 #17
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        self.resolution = resolution = self.options['resolution']
        self.face_type = {'h'  : FaceType.HALF,
                          'mf' : FaceType.MID_FULL,
                          'f'  : FaceType.FULL,
                          'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]

        eyes_prio = self.options['eyes_prio']
        archi = self.options['archi']
        is_hd = 'hd' in archi
        ae_dims = self.options['ae_dims']
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0

        masked_training = self.options['masked_training']
        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None

        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch=3
        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
            
            self.src_code_in = tf.placeholder (nn.floatx, (None,256) )

            self.target_src = tf.placeholder (nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

            self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
            self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)
            
        # Initializing model classes
        model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None)  
        
        with tf.device (models_opt_device):
            if 'df' in archi:
                self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

                self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
                inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

                self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
                self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')

                self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                              [self.inter,       'inter.npy'      ],
                                              [self.decoder_src, 'decoder_src.npy'],
                                              [self.decoder_dst, 'decoder_dst.npy']  ]

                if self.is_training:
                    if self.options['true_face_power'] != 0:
                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
                        self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

            elif 'liae' in archi:
                self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

                self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
                self.inter_B  = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')

                inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
                inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
                inters_out_ch = inter_AB_out_ch+inter_B_out_ch
                self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')

                self.model_filename_list += [ [self.encoder,  'encoder.npy'],
                                              [self.inter_AB, 'inter_AB.npy'],
                                              [self.inter_B , 'inter_B.npy'],
                                              [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                if gan_power != 0:
                    self.D_src = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
                    self.D_dst = nn.PatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_dst")
                    self.model_filename_list += [ [self.D_src, 'D_src.npy'] ]
                    self.model_filename_list += [ [self.D_dst, 'D_dst.npy'] ]

                # Initialize optimizers
                lr=5e-5
                lr_dropout = 0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0
                self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
                if 'df' in archi:
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
                elif 'liae' in archi:
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()

                self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)

                if self.options['true_face_power'] != 0:
                    self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
                    self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

                if gan_power != 0:
                    self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
                    self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights()+self.D_dst.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
                    self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_dst_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)


            # Compute losses per GPU
            gpu_src_latent_code_list = []
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []
            for gpu_id in range(gpu_count):
                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src      = self.warped_src [batch_slice,:,:,:]
                        gpu_src_code_in      = self.src_code_in[batch_slice,:]
                        gpu_warped_dst      = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src      = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst      = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
                        gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]
                        
                    # process model tensors
                    if 'df' in archi:
                        gpu_src_latent_code = self.inter.dense1(self.encoder(gpu_warped_src))
                        
                        gpu_src_in_code = self.inter.fd(gpu_src_code_in)
                        
                        gpu_src_code     = self.inter(self.encoder(gpu_warped_src))
                        gpu_dst_code     = self.inter(self.encoder(gpu_warped_dst))
                        
                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_in_code)
                        #gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    elif 'liae' in archi:
                        gpu_src_code = self.encoder (gpu_warped_src)
                        gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                        gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis  )
                        gpu_dst_code = self.encoder (gpu_warped_dst)
                        gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                        gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                        gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
                        gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )

                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                    gpu_src_latent_code_list.append(gpu_src_latent_code)
                    
                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
                    
                    # unpack masks from one combined mask
                    gpu_target_srcm      = tf.clip_by_value (gpu_target_srcm_all, 0, 1)                                   
                    gpu_target_dstm      = tf.clip_by_value (gpu_target_dstm_all, 0, 1)                    
                    gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)   
                    gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)

                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm,  max(1, resolution // 32) )
                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm,  max(1, resolution // 32) )

                    gpu_target_dst_masked      = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt  = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt  = gpu_target_dst_masked if masked_training else gpu_target_dst
                    
                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur)

                    gpu_src_loss =  tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
                    
                    if eyes_prio:
                        gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])
                    
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                    face_style_power = self.options['face_style_power'] / 100.0
                    if face_style_power != 0 and not self.pretrain:
                        gpu_src_loss += nn.style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                    bg_style_power = self.options['bg_style_power'] / 100.0
                    if bg_style_power != 0 and not self.pretrain:
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )

                    gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square(  gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
                    
                    if eyes_prio:
                        gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])
                    
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLoss(labels, logits):
                        """Sigmoid cross-entropy between ``labels`` and ``logits``, averaged over axes 1-3.

                        Callers pass a ``tf.ones_like`` / ``tf.zeros_like`` of the
                        logits as ``labels``, so the result is one loss value per
                        batch entry.
                        """
                        per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
                        return tf.reduce_mean(per_element, axis=[1, 2, 3])

                    if self.options['true_face_power'] != 0:
                        gpu_src_code_d = self.code_discriminator( gpu_src_code )
                        gpu_src_code_d_ones  = tf.ones_like (gpu_src_code_d)
                        gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                        gpu_dst_code_d = self.code_discriminator( gpu_dst_code )
                        gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                        gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                        gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
                                           DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                        gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

                    if gan_power != 0:
                        gpu_pred_src_src_d       = self.D_src(gpu_pred_src_src_masked_opt)
                        gpu_pred_src_src_d_ones  = tf.ones_like (gpu_pred_src_src_d)
                        gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)
                        gpu_target_src_d         = self.D_src(gpu_target_src_masked_opt)
                        gpu_target_src_d_ones    = tf.ones_like(gpu_target_src_d)
                        gpu_pred_dst_dst_d       = self.D_dst(gpu_pred_dst_dst_masked_opt)
                        gpu_pred_dst_dst_d_ones  = tf.ones_like (gpu_pred_dst_dst_d)
                        gpu_pred_dst_dst_d_zeros = tf.zeros_like(gpu_pred_dst_dst_d)
                        gpu_target_dst_d         = self.D_dst(gpu_target_dst_masked_opt)
                        gpu_target_dst_d_ones    = tf.ones_like(gpu_target_dst_d)

                        gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones   , gpu_target_src_d) + \
                                              DLoss(gpu_pred_src_src_d_zeros, gpu_pred_src_src_d) ) * 0.5 + \
                                             (DLoss(gpu_target_dst_d_ones   , gpu_target_dst_d) + \
                                              DLoss(gpu_pred_dst_dst_d_zeros, gpu_pred_dst_dst_d) ) * 0.5

                        gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights()+self.D_dst.get_weights() ) ]

                        gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + DLoss(gpu_pred_dst_dst_d_ones, gpu_pred_dst_dst_d))


                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]


            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                src_latent_code  = nn.concat(gpu_src_latent_code_list, 0)
                pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = tf.concat(gpu_src_losses, 0)
                dst_loss = tf.concat(gpu_dst_losses, 0)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

                if self.options['true_face_power'] != 0:
                    D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs))

                if gan_power != 0:
                    src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )


            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm_all,
                              warped_dst, target_dst, target_dstm_all):
                """Run one generator (encoder/inter/decoder) update on a src/dst batch.

                Feeds all six training placeholders and executes
                ``src_dst_loss_gv_op``. Returns the ``src_loss`` / ``dst_loss``
                values fetched in the same run. These are the concatenated
                per-GPU loss tensors; no reduction to a scalar happens here.
                """
                feed = {
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm_all: target_srcm_all,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm_all: target_dstm_all,
                }
                src_val, dst_val, _unused = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op], feed_dict=feed)
                return src_val, dst_val
            self.src_dst_train = src_dst_train

            if self.options['true_face_power'] != 0:
                def D_train(warped_src, warped_dst):
                    """Apply one update of the latent-code discriminator optimizer (``D_loss_gv_op``).

                    Only the warped inputs are fed. Nothing is returned.
                    """
                    feed = {self.warped_src: warped_src,
                            self.warped_dst: warped_dst}
                    nn.tf_sess.run([D_loss_gv_op], feed_dict=feed)
                self.D_train = D_train

            if gan_power != 0:
                def D_src_dst_train(warped_src, target_src, target_srcm_all,
                                    warped_dst, target_dst, target_dstm_all):
                    """Apply one update of the D_src/D_dst discriminator optimizer.

                    Feeds all six training placeholders and runs
                    ``src_D_src_dst_loss_gv_op``. Nothing is returned.
                    """
                    feed = {
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.target_srcm_all: target_srcm_all,
                        self.warped_dst: warped_dst,
                        self.target_dst: target_dst,
                        self.target_dstm_all: target_dstm_all,
                    }
                    nn.tf_sess.run([src_D_src_dst_loss_gv_op], feed_dict=feed)
                self.D_src_dst_train = D_src_dst_train

            def AE_get_latent(warped_src):
                """Evaluate ``src_latent_code`` (per-GPU latent codes concatenated on axis 0) for ``warped_src``."""
                feed = {self.warped_src: warped_src}
                latent = nn.tf_sess.run(src_latent_code, feed_dict=feed)
                return latent
            self.AE_get_latent = AE_get_latent
            
            def AE_view_src(warped_src, src_code_in):
                """Fetch ``pred_src_src`` while feeding both ``warped_src`` and ``src_code_in``.

                NOTE(review): presumably ``self.src_code_in`` overrides the
                encoder's latent code in the graph built above. Confirm how it
                is wired into ``pred_src_src``.
                """
                feed = {
                    self.warped_src: warped_src,
                    self.src_code_in: src_code_in,
                }
                decoded = nn.tf_sess.run(pred_src_src, feed_dict=feed)
                return decoded
            self.AE_view_src = AE_view_src
            
            def AE_view(warped_src, warped_dst):
                """Return the five prediction tensors for a warped src/dst batch.

                Order: ``[pred_src_src, pred_dst_dst, pred_dst_dstm,
                pred_src_dst, pred_src_dstm]``.
                """
                fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]
                feed = {
                    self.warped_src: warped_src,
                    self.warped_dst: warped_dst,
                }
                return nn.tf_sess.run(fetches, feed_dict=feed)
            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                if 'df' in archi:
                    gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

                elif 'liae' in archi:
                    gpu_dst_code = self.encoder (self.warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)

            
            def AE_merge( warped_dst):
                """Run the merge graph on ``warped_dst``.

                Returns ``[gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]``:
                the dst code decoded as src, plus the two predicted masks.
                """
                fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                return nn.tf_sess.run(fetches, feed_dict={self.warped_dst: warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if 'df' in archi:
                    if model == self.inter:
                        do_init = True
                elif 'liae' in archi:
                    if model == self.inter_AB:
                        do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if ct_mode is not None:
                src_generators_count = int(src_generators_count * 1.5)

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                           'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False                      , 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                           'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.PITCH_YAW_ROLL, 'resolution': resolution},
                                                
                                              ],
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                                'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                                'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        generators_count=dst_generators_count )
                             ])
            
            self.last_src_samples_loss = []
            self.last_dst_samples_loss = []
            
            if self.pretrain_just_disabled:
                self.update_sample_for_preview(force_new=True)
                
        class PRD(nn.ModelBase):
            """Experimental latent-code predictor conditioned on a yaw difference.

            Takes a latent code with ``ae_ch`` features plus a one-feature yaw
            delta and predicts another latent code with ``ae_ch`` features
            through a stack of dense layers.
            """

            def on_build(self, ae_ch):
                # ae_ch + 1 inputs: the latent code with the yaw delta appended in forward().
                self.dense1 = nn.Dense( ae_ch+1, 1024 )
                self.dense2 = nn.Dense( 1024, 2048 )
                self.dense3 = nn.Dense( 2048, 4096 )
                self.dense4 = nn.Dense( 4096, 4096 )
                # Output layer stays linear: its result is compared directly
                # against a latent code.
                self.dense5 = nn.Dense( 4096, ae_ch )

            def forward(self, inp, yaw_in):
                """Predict a latent code from ``inp`` and the yaw delta ``yaw_in``.

                ``yaw_in`` is concatenated onto ``inp`` along the last axis
                before the dense stack.
                """
                x = tf.concat( [inp, yaw_in], -1 )

                # A non-linearity is required between hidden layers. Without it
                # the five Dense layers compose into a single affine map, and the
                # wide hidden layers add no modelling capacity.
                for hidden in (self.dense1, self.dense2, self.dense3, self.dense4):
                    x = tf.nn.leaky_relu(hidden(x), 0.1)

                return self.dense5(x)
                
        
        
       
        
        with tf.device( f'/GPU:0'):
            prd_model = PRD(256, name='PRD')
            
            prd_model.init_weights()
             
            prd_in = tf.placeholder (nn.floatx, (None,256) )
            prd_targ = tf.placeholder (nn.floatx, (None,256) )
            yaw_diff_in = tf.placeholder (nn.floatx, (None,1) )
            
            prd_out = prd_model(prd_in, yaw_diff_in)
            
            loss = tf.reduce_sum ( tf.abs (prd_out - prd_targ) )
            
            loss_gvs = nn.gradients (loss, prd_model.get_weights() )
            
            prd_opt = nn.RMSprop(lr=5e-6, lr_dropout=0.3, name='prd_opt')
            prd_opt.initialize_variables(prd_model.get_weights())
            prd_opt.init_weights()
            
            loss_gv_op = prd_opt.get_update_op (loss_gvs)
            
            
                
        s_gen, _ = self.get_training_data_generators()
        bs = self.get_batch_size()
        
        for n in range(1000):
            warped_src, target_src, target_srcm_all, src_pyr = s_gen.generate_next()
        
            
        
            sl = self.AE_get_latent(target_src)
            
            prd_in_np = []
            prd_targ_np = []
            yaw_diff_in_np = []
            for i in range(bs):
                prd_in_np += [sl[i]]
                
                j = i
                while j == i:
                    j = np.random.randint(bs)
                
                prd_targ_np += [ sl[j] ]
                
                yaw_diff_in_np += [ np.float32( [ src_pyr[j][1]-src_pyr[i][1] ] ) ]
                
            prd_loss, _ = nn.tf_sess.run([loss, loss_gv_op], feed_dict={prd_in:prd_in_np, prd_targ:prd_targ_np, yaw_diff_in:yaw_diff_in_np} )
            print(f'{n} loss = {prd_loss}')
        
        
        warped_src, target_src, target_srcm_all, src_pyr = s_gen.generate_next()
        sl = self.AE_get_latent(target_src)
        
        yaw_diff_in_np = np.float32( [ [-0.4] ] *bs )
        
        new_sl = nn.tf_sess.run(prd_out, feed_dict={prd_in:sl, yaw_diff_in:yaw_diff_in_np} )
                
        new_target_src = self.AE_view_src( target_src, new_sl )            
        
        target_src = np.clip( nn.to_data_format( target_src ,"NHWC", self.model_data_format), 0.0, 1.0)
        new_target_src = np.clip( nn.to_data_format( new_target_src ,"NHWC", self.model_data_format), 0.0, 1.0)
        for i in range(bs):
            
            screen = np.concatenate ( (target_src[i], new_target_src[i]), 1 )
            cv2.imshow("", (screen*255).astype(np.uint8) )
            cv2.waitKey(0)
            
        import code
        code.interact(local=dict(globals(), **locals()))    
예제 #18
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(
            devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        resolution = self.resolution = 96
        self.face_type = FaceType.FULL
        ae_dims = 128
        e_dims = 128
        d_dims = 64
        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        models_opt_on_gpu = len(devices) >= 1 and all(
            [dev.total_mem_gb >= 4 for dev in devices])
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        self.model_filename_list = []

        model_archi = nn.DeepFakeArchi(resolution, mod='quick')

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
            self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = model_archi.Encoder(in_ch=input_ch,
                                               e_ch=e_dims,
                                               name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels(
                (nn.floatx, bgr_shape))

            self.inter = model_archi.Inter(in_ch=encoder_out_ch,
                                           ae_ch=ae_dims,
                                           ae_out_ch=ae_dims,
                                           d_ch=d_dims,
                                           name='inter')
            inter_out_ch = self.inter.compute_output_channels(
                (nn.floatx, (None, encoder_out_ch)))

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch,
                                                   d_ch=d_dims,
                                                   name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch,
                                                   d_ch=d_dims,
                                                   name='decoder_dst')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                self.src_dst_trainable_weights = self.encoder.get_weights(
                ) + self.inter.get_weights() + self.decoder_src.get_weights(
                ) + self.decoder_dst.get_weights()

                # Initialize optimizers
                self.src_dst_opt = nn.RMSprop(lr=2e-4,
                                              lr_dropout=0.3,
                                              name='src_dst_opt')
                self.src_dst_opt.initialize_variables(
                    self.src_dst_trainable_weights,
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt,
                                              'src_dst_opt.npy')]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm = self.target_srcm[
                            batch_slice, :, :, :]
                        gpu_target_dstm = self.target_dstm[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                        gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                        gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                        gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(
                        gpu_target_srcm, max(1, resolution // 32))
                    gpu_target_dstm_blur = nn.gaussian_blur(
                        gpu_target_dstm, max(1, resolution // 32))

                    gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_src_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_src_masked_opt,
                                      gpu_pred_src_src_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_src_masked_opt -
                                       gpu_pred_src_src_masked_opt),
                        axis=[1, 2, 3])
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                        axis=[1, 2, 3])

                    gpu_dst_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_dst_masked_opt,
                                      gpu_pred_dst_dst_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dst_masked_opt -
                                       gpu_pred_dst_dst_masked_opt),
                        axis=[1, 2, 3])
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                        axis=[1, 2, 3])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [
                        nn.gradients(gpu_G_loss,
                                     self.src_dst_trainable_weights)
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                    src_dst_loss_gv)

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm,
                              warped_dst, target_dst, target_dstm):
                """Run one generator update and return the mean src and dst losses.

                Feeds all six training placeholders, executes
                ``src_dst_loss_gv_op`` and reduces the fetched ``src_loss`` /
                ``dst_loss`` values with ``np.mean``.
                """
                fetches = [src_loss, dst_loss, src_dst_loss_gv_op]
                feed = {
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                }
                src_val, dst_val, _unused = nn.tf_sess.run(fetches, feed_dict=feed)
                return np.mean(src_val), np.mean(dst_val)

            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                """Return the five prediction tensors for a warped src/dst batch.

                Order: ``[pred_src_src, pred_dst_dst, pred_dst_dstm,
                pred_src_dst, pred_src_dstm]``.
                """
                fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm,
                           pred_src_dst, pred_src_dstm]
                feed = {self.warped_src: warped_src,
                        self.warped_dst: warped_dst}
                return nn.tf_sess.run(fetches, feed_dict=feed)

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge(warped_dst):
                """Run the merge graph on ``warped_dst``.

                Returns ``[gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]``:
                the dst code decoded by ``decoder_src``, the ``decoder_dst``
                mask and the ``decoder_src`` mask.
                """
                fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                feed = {self.warped_dst: warped_dst}
                return nn.tf_sess.run(fetches, feed_dict=feed)

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(
                self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(
                    self.get_strpath_storage_for_file(filename))

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
            )
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
            )

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            self.set_training_data_generators([
                SampleGeneratorFace(
                    training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[{
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp':
                        True,
                        'transform':
                        True,
                        'channel_type':
                        SampleProcessor.ChannelType.BGR,
                        'face_type':
                        self.face_type,
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }, {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp':
                        False,
                        'transform':
                        True,
                        'channel_type':
                        SampleProcessor.ChannelType.BGR,
                        'face_type':
                        self.face_type,
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }, {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_MASK,
                        'warp':
                        False,
                        'transform':
                        True,
                        'channel_type':
                        SampleProcessor.ChannelType.G,
                        'face_mask_type':
                        SampleProcessor.FaceMaskType.FULL_FACE,
                        'face_type':
                        self.face_type,
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }],
                    generators_count=src_generators_count),
                SampleGeneratorFace(
                    training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[{
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp':
                        True,
                        'transform':
                        True,
                        'channel_type':
                        SampleProcessor.ChannelType.BGR,
                        'face_type':
                        self.face_type,
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }, {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp':
                        False,
                        'transform':
                        True,
                        'channel_type':
                        SampleProcessor.ChannelType.BGR,
                        'face_type':
                        self.face_type,
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }, {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_MASK,
                        'warp':
                        False,
                        'transform':
                        True,
                        'channel_type':
                        SampleProcessor.ChannelType.G,
                        'face_mask_type':
                        SampleProcessor.FaceMaskType.FULL_FACE,
                        'face_type':
                        self.face_type,
                        'data_format':
                        nn.data_format,
                        'resolution':
                        resolution
                    }],
                    generators_count=dst_generators_count)
            ])

            self.last_samples = None
예제 #19
0
파일: Model.py 프로젝트: rumax/DeepFaceLab
    def on_initialize(self):
        """Build the FANSeg (TernausNet) segmentation model and, when training, its graph.

        The ``nn`` backend is initialized in NHWC layout and a TernausNet is
        created at a fixed 256px resolution for ``FaceType.FULL``. In training
        mode this also builds a per-device (data-parallel) sigmoid
        cross-entropy loss with averaged gradients, exposes
        ``self.train(input_np, target_np)`` and ``self.view(input_np)``, and
        registers the src/dst sample generators.
        """
        nn.initialize(data_format="NHWC")
        tf = nn.tf

        # Query the active device configuration once the backend is initialized.
        devices = nn.getCurrentDeviceConfig().devices

        # Resolution and face type are fixed for this model (not user options).
        self.resolution = resolution = 256
        self.face_type = FaceType.FULL

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

        # Initializing model classes
        self.model = TernausNet(f'{self.model_name}_FANSeg',
                                resolution,
                                FaceType.toString(self.face_type),
                                load_weights=not self.is_first_run(),
                                weights_file_root=self.get_model_root_path(),
                                training=True,
                                place_model_on_cpu=place_model_on_cpu)

        if not self.is_training:
            return

        # Split the batch evenly across devices (at least one sample per device)
        # and round the global batch size down to a multiple of the device count.
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute predictions, losses and gradients per device
        gpu_pred_list = []
        gpu_losses = []
        gpu_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):

                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                    gpu_target_t = self.model.target_t[batch_slice, :, :, :]

                # process model tensors
                gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
                gpu_pred_list.append(gpu_pred_t)

                # Per-sample loss: sigmoid cross-entropy averaged over axes 1..3.
                gpu_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=gpu_target_t, logits=gpu_pred_logits_t),
                    axis=[1, 2, 3])
                gpu_losses.append(gpu_loss)

                gpu_loss_gvs.append(
                    nn.tf_gradients(gpu_loss, self.model.net_weights))

        # Average losses and gradients, and create optimizer update ops
        with tf.device(models_opt_device):
            pred = nn.tf_concat(gpu_pred_list, 0)
            loss = tf.reduce_mean(gpu_losses)

            loss_gv_op = self.model.opt.get_update_op(
                nn.tf_average_gv_list(gpu_loss_gvs))

        # Initializing training and view functions
        def train(input_np, target_np):
            """Run one optimizer step on a batch and return the mean loss."""
            l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                  feed_dict={
                                      self.model.input_t: input_np,
                                      self.model.target_t: target_np
                                  })
            return l

        self.train = train

        def view(input_np):
            """Return ``[pred]`` (a one-element list) for the given inputs."""
            return nn.tf_sess.run([pred],
                                  feed_dict={self.model.input_t: input_np})

        self.view = view

        # initializing sample generators
        training_data_src_path = self.training_data_src_path
        training_data_dst_path = self.training_data_dst_path

        cpu_count = min(multiprocessing.cpu_count(), 8)
        # src gets 1.5x the base worker share; dst keeps the base share.
        src_generators_count = int((cpu_count // 2) * 1.5)
        dst_generators_count = cpu_count // 2

        src_generator = SampleGeneratorFace(
            training_data_src_path,
            random_ct_samples_path=training_data_src_path,
            debug=self.is_debug(),
            batch_size=self.get_batch_size(),
            sample_process_options=SampleProcessor.Options(
                random_flip=True),
            output_sample_types=[
                {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'ct_mode': 'lct',
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'motion_blur': (25, 5),
                    'gaussian_blur': (25, 5),
                    'data_format': nn.data_format,
                    'resolution': resolution
                },
                {
                    'sample_type': SampleProcessor.SampleType.FACE_MASK,
                    'warp': True,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.G,
                    'face_mask_type':
                    SampleProcessor.FaceMaskType.FULL_FACE,
                    'face_type': self.face_type,
                    'data_format': nn.data_format,
                    'resolution': resolution
                },
            ],
            generators_count=src_generators_count)

        # dst samples are only used for viewing; missing data is tolerated.
        dst_generator = SampleGeneratorFace(
            training_data_dst_path,
            debug=self.is_debug(),
            batch_size=self.get_batch_size(),
            sample_process_options=SampleProcessor.Options(
                random_flip=True),
            output_sample_types=[
                {
                    'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                    'warp': False,
                    'transform': True,
                    'channel_type': SampleProcessor.ChannelType.BGR,
                    'face_type': self.face_type,
                    'motion_blur': (25, 5),
                    'gaussian_blur': (25, 5),
                    'data_format': nn.data_format,
                    'resolution': resolution
                },
            ],
            generators_count=dst_generators_count,
            raise_on_no_data=False)
        if not dst_generator.is_initialized():
            io.log_info(
                f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n"
            )

        self.set_training_data_generators([src_generator, dst_generator])
예제 #20
0
    def on_initialize(self):
        """Build a shared-encoder, two-decoder (src/dst) autoencoder at a fixed 96px.

        Uses NCHW when at least one device is configured and debug mode is off,
        NHWC otherwise. Defines the network building blocks, creates the
        encoder / inter / decoder_src / decoder_dst models (plus an RMSprop
        optimizer when training), then builds either the data-parallel training
        graph (``self.src_dst_train`` and ``self.AE_view``) or the merge
        function (``self.AE_merge``). Finally loads or initializes every entry
        of ``self.model_filename_list`` and, when training, registers the
        src/dst sample generators.
        """
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        conv_kernel_initializer = nn.initializers.ca()

        class Downscale(nn.ModelBase):
            """Halve spatial size: stride-2 conv, or stride-1 conv + space_to_depth(2) when subpixel."""
            def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *args, **kwargs):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.subpixel = subpixel
                self.use_activator = use_activator
                # Forward both extra positional and keyword arguments (e.g. name=)
                # to the base class.
                super().__init__(*args, **kwargs)

            def on_build(self, *args, **kwargs):
                self.conv1 = nn.Conv2D(self.in_ch,
                                       self.out_ch // (4 if self.subpixel else 1),
                                       kernel_size=self.kernel_size,
                                       strides=1 if self.subpixel else 2,
                                       padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)

                if self.subpixel:
                    x = nn.tf_space_to_depth(x, 2)

                if self.use_activator:
                    x = nn.tf_gelu(x)
                return x

            def get_out_ch(self):
                # Subpixel path: the conv emits out_ch // 4 channels, which
                # space_to_depth(2) multiplies by 4. Strided path: the conv
                # emits out_ch channels directly.
                if self.subpixel:
                    return (self.out_ch // 4) * 4
                return self.out_ch

        class DownscaleBlock(nn.ModelBase):
            """Chain of n_downscales Downscale layers; width ch * min(2**i, 8)."""
            def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
                self.downs = []

                last_ch = in_ch
                for i in range(n_downscales):
                    cur_ch = ch * min(2**i, 8)
                    self.downs.append(Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel))
                    last_ch = self.downs[-1].get_out_ch()

            def forward(self, inp):
                x = inp
                for down in self.downs:
                    x = down(x)
                return x

        class Upscale(nn.ModelBase):
            """Conv to out_ch*4 channels, GELU, then depth_to_space(2)."""
            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                x = nn.tf_gelu(x)
                x = nn.tf_depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            """Two same-width convs with a skip connection and GELU activations."""
            def on_build(self, ch, kernel_size=3):
                self.conv1 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
                self.conv2 = nn.Conv2D(ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                x = self.conv1(inp)
                x = nn.tf_gelu(x)
                x = self.conv2(x)
                x = inp + x
                x = nn.tf_gelu(x)
                return x

        class Encoder(nn.ModelBase):
            """Four 5x5 downscales followed by flattening."""
            def on_build(self, in_ch, e_ch):
                self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)

            def forward(self, inp):
                return nn.tf_flatten(self.down1(inp))

        class Inter(nn.ModelBase):
            """Dense bottleneck reshaped to a lowest_dense_res grid, then one upscale + residual."""
            def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch, **kwargs):
                self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch
                super().__init__(**kwargs)

            def on_build(self):
                in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch

                self.dense1 = nn.Dense(in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal)
                self.dense2 = nn.Dense(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, maxout_features=4, kernel_initializer=tf.initializers.orthogonal)
                self.upscale1 = Upscale(ae_out_ch, d_ch*8)
                self.res1 = ResidualBlock(d_ch*8)

            def forward(self, inp):
                x = self.dense1(inp)
                x = self.dense2(x)
                # Use the instance's own grid size, not the enclosing scope's
                # variable, so the reshape matches what on_build allocated.
                x = nn.tf_reshape_4D(x, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch)
                x = self.upscale1(x)
                x = self.res1(x)
                return x

            def get_out_ch(self):
                # NOTE(review): returns ae_out_ch although forward() ends with a
                # d_ch*8-channel ResidualBlock. It is not called in this method
                # (compute_output_channels is used instead); confirm intent.
                return self.ae_out_ch

        class Decoder(nn.ModelBase):
            """Image branch (3x upscale+residual, 3-ch sigmoid) and mask branch (3x upscale, 1-ch sigmoid)."""
            def on_build(self, in_ch, d_ch):
                self.upscale1 = Upscale(in_ch, d_ch*4)
                self.res1     = ResidualBlock(d_ch*4)
                self.upscale2 = Upscale(d_ch*4, d_ch*2)
                self.res2     = ResidualBlock(d_ch*2)
                self.upscale3 = Upscale(d_ch*2, d_ch*1)
                self.res3     = ResidualBlock(d_ch*1)

                self.upscalem1 = Upscale(in_ch, d_ch)
                self.upscalem2 = Upscale(d_ch, d_ch//2)
                self.upscalem3 = Upscale(d_ch//2, d_ch//2)

                self.out_conv = nn.Conv2D(d_ch*1, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
                self.out_convm = nn.Conv2D(d_ch//2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)

            def forward(self, inp):
                z = inp
                x = self.upscale1(z)
                x = self.res1(x)
                x = self.upscale2(x)
                x = self.res2(x)
                x = self.upscale3(x)
                x = self.res3(x)

                y = self.upscalem1(z)
                y = self.upscalem2(y)
                y = self.upscalem3(y)

                return tf.nn.sigmoid(self.out_conv(x)), \
                       tf.nn.sigmoid(self.out_convm(y))

        # Re-read the device list after nn.initialize.
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        # Fixed architecture hyper-parameters (not user options).
        resolution = self.resolution = 96
        ae_dims = 128
        e_dims = 128
        d_dims = 64
        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        # Keep models/optimizer on GPU 0 only if every device has >= 2 GB.
        models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 2 for dev in devices])
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)
        # The encoder performs 4 downscales, hence resolution // 16.
        lowest_dense_res = resolution // 16

        self.model_filename_list = []

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder(nn.tf_floatx, bgr_shape)
            self.warped_dst = tf.placeholder(nn.tf_floatx, bgr_shape)

            self.target_src = tf.placeholder(nn.tf_floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.tf_floatx, bgr_shape)

            self.target_srcm = tf.placeholder(nn.tf_floatx, mask_shape)
            self.target_dstm = tf.placeholder(nn.tf_floatx, mask_shape)

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels((nn.tf_floatx, bgr_shape))

            self.inter = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, d_ch=d_dims, name='inter')
            inter_out_ch = self.inter.compute_output_channels((nn.tf_floatx, (None, encoder_out_ch)))

            self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_src')
            self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, name='decoder_dst')

            self.model_filename_list += [[self.encoder,     'encoder.npy'    ],
                                         [self.inter,       'inter.npy'      ],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()

                # Initialize optimizers
                self.src_dst_opt = nn.TFRMSpropOptimizer(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

        if self.is_training:
            # Global batch size is fixed at 4, split evenly across devices
            # (at least one sample per device).
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size(gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                    batch_slice = slice(gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu)
                    with tf.device('/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to GPU first
                        gpu_warped_src   = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst   = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src   = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst   = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm  = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_dstm  = self.target_dstm[batch_slice,:,:,:]

                    # process model tensors: shared encoder/inter, per-identity decoders
                    gpu_src_code     = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code     = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32))
                    gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32))

                    gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur

                    # With masked_training, compare images only inside the blurred masks.
                    gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    # Per-sample loss: 10*DSSIM + 10*MSE on images, plus 10*MSE on masks.
                    gpu_src_loss =  tf.reduce_mean(10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1,2,3])
                    gpu_src_loss += tf.reduce_mean(10*tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1,2,3])

                    gpu_dst_loss  = tf.reduce_mean(10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean(10*tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1,2,3])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [nn.tf_gradients(gpu_G_loss, self.src_dst_trainable_weights)]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred_src_src  = nn.tf_concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.tf_concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.tf_concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.tf_average_tensor_list(gpu_src_losses)
                dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.tf_average_gv_list(gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, \
                              warped_dst, target_dst, target_dstm):
                """Run one optimizer step; return (mean src loss, mean dst loss)."""
                s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                         feed_dict={self.warped_src :warped_src,
                                                    self.target_src :target_src,
                                                    self.target_srcm:target_srcm,
                                                    self.warped_dst :warped_dst,
                                                    self.target_dst :target_dst,
                                                    self.target_dstm:target_dstm,
                                                    })
                s = np.mean(s)
                d = np.mean(d)
                return s, d
            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                """Return [src->src, dst->dst, dst mask, src-on-dst, src-on-dst mask] predictions."""
                return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                      feed_dict={self.warped_src:warped_src,
                                                 self.warped_dst:warped_dst})

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device('/GPU:0' if len(devices) != 0 else '/CPU:0'):
                gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge(warped_dst):
                """Return [src-on-dst image, dst mask, src-on-dst mask] for dst inputs."""
                return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

            # Fall back to a pretrained model's weights when available.
            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            self.set_training_data_generators([
                    SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False),
                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution':resolution, },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),        'data_format':nn.data_format, 'resolution': resolution, },
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M),          'data_format':nn.data_format, 'resolution': resolution } ],
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False),
                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution':resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),        'data_format':nn.data_format, 'resolution': resolution},
                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M),          'data_format':nn.data_format, 'resolution': resolution} ],
                        generators_count=dst_generators_count )
                             ])

            self.last_samples = None
예제 #21
0
    def on_initialize_options(self):
        """Load saved model options and prompt the user to edit them when appropriate.

        Every option is first seeded into ``self.options`` from its saved value
        (via ``self.load_or_def_option``) or a hard-coded default. Structural
        options (resolution, face type, AE architecture and the AE / encoder /
        decoder / mask dimensions) are prompted for only on the first run.
        The remaining training options are prompted for on the first run or
        when ``self.ask_override()`` returns true.

        Side effects:
            Populates ``self.options`` and sets ``self.pretrain_just_disabled``
            to True when the loaded options had pretraining enabled and it has
            now been switched off.

        Raises:
            Exception: if pretraining is enabled but
                ``self.get_pretraining_data_path()`` returns None.
        """
        device_config = nn.getCurrentDeviceConfig()

        # Suggest a batch size from the weakest device's memory; when no
        # devices are listed, fall back to assuming 2 GB.
        lowest_vram = 2
        if len(device_config.devices) != 0:
            lowest_vram = device_config.devices.get_worst_device().total_mem_gb
        suggest_batch_size = 8 if lowest_vram >= 4 else 4

        # Seed self.options with saved-or-default values. Dicts keep insertion
        # order, so the order of these assignments also fixes the key order of
        # self.options (keys that already exist keep their position).
        default_resolution         = self.options['resolution']         = self.load_or_def_option('resolution', 128)
        default_face_type          = self.options['face_type']          = self.load_or_def_option('face_type', 'f')
        default_models_opt_on_gpu  = self.options['models_opt_on_gpu']  = self.load_or_def_option('models_opt_on_gpu', True)
        default_archi              = self.options['archi']              = self.load_or_def_option('archi', 'df')
        default_ae_dims            = self.options['ae_dims']            = self.load_or_def_option('ae_dims', 256)
        default_e_dims             = self.options['e_dims']             = self.load_or_def_option('e_dims', 64)
        # The decoder defaults depend on the architecture, which may not be
        # chosen yet. Only reserve these keys here (None when nothing was
        # saved); their defaults are resolved after the architecture prompt.
        self.options['d_dims']      = self.options.get('d_dims', None)
        self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
        default_masked_training    = self.options['masked_training']    = self.load_or_def_option('masked_training', True)
        default_eyes_prio          = self.options['eyes_prio']          = self.load_or_def_option('eyes_prio', False)
        default_lr_dropout         = self.options['lr_dropout']         = self.load_or_def_option('lr_dropout', False)
        default_random_warp        = self.options['random_warp']        = self.load_or_def_option('random_warp', True)
        default_gan_power          = self.options['gan_power']          = self.load_or_def_option('gan_power', 0.0)
        default_true_face_power    = self.options['true_face_power']    = self.load_or_def_option('true_face_power', 0.0)
        default_face_style_power   = self.options['face_style_power']   = self.load_or_def_option('face_style_power', 0.0)
        default_bg_style_power     = self.options['bg_style_power']     = self.load_or_def_option('bg_style_power', 0.0)
        default_ct_mode            = self.options['ct_mode']            = self.load_or_def_option('ct_mode', 'none')
        default_clipgrad           = self.options['clipgrad']           = self.load_or_def_option('clipgrad', False)
        default_pretrain           = self.options['pretrain']           = self.load_or_def_option('pretrain', False)

        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_write_preview_history()
            self.ask_target_iter()
            self.ask_random_flip()
            self.ask_batch_size(suggest_batch_size)

        if self.is_first_run():
            resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
            # Round down to a multiple of 16, then clamp to the supported range.
            resolution = np.clip((resolution // 16) * 16, 64, 512)
            self.options['resolution'] = resolution
            self.options['face_type'] = io.input_str("Face type", default_face_type, ['h','mf','f','wf'], help_message="Half / mid face / full face / whole face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead, but requires manual merge in Adobe After Effects.").lower()
            self.options['archi'] = io.input_str("AE architecture", default_archi, ['df','liae','dfhd','liaehd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower()

        # Now that the architecture is known, resolve the decoder defaults:
        # 48 for 'dfhd', 64 otherwise; mask dims default to d_dims // 3,
        # rounded up to an even number.
        default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
        default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)

        default_d_mask_dims = default_d_dims // 3
        default_d_mask_dims += default_d_mask_dims % 2
        default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)

        if self.is_first_run():
            self.options['ae_dims'] = np.clip(io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU."), 32, 1024)

            # Encoder / decoder / mask dims are clamped, then bumped to even.
            e_dims = np.clip(io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."), 16, 256)
            self.options['e_dims'] = e_dims + e_dims % 2

            d_dims = np.clip(io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU."), 16, 256)
            self.options['d_dims'] = d_dims + d_dims % 2

            d_mask_dims = np.clip(io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality."), 16, 256)
            self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2

        if self.is_first_run() or ask_override:
            # Masked training is only offered for whole face ('wf'); for other
            # face types the loaded/default value is kept unchanged.
            if self.options['face_type'] == 'wf':
                self.options['masked_training'] = io.input_bool("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly.  When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.")

            self.options['eyes_prio'] = io.input_bool("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ')

            self.options['models_opt_on_gpu'] = io.input_bool("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")

            self.options['lr_dropout'] = io.input_bool("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations.")
            self.options['random_warp'] = io.input_bool("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")

            self.options['gan_power'] = np.clip(io.input_number("GAN power", default_gan_power, add_info="0.0 .. 10.0", help_message="Train the network in Generative Adversarial manner. Accelerates the speed of training. Forces the neural network to learn small details of the face. You can enable/disable this option at any time. Typical value is 1.0"), 0.0, 10.0)

            # 'True face' power is only offered for 'df'-family architectures;
            # any other architecture forces it to 0.0.
            if 'df' in self.options['archi']:
                self.options['true_face_power'] = np.clip(io.input_number("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0)
            else:
                self.options['true_face_power'] = 0.0

            # Style powers are not offered for whole face; the loaded values
            # are kept unchanged in that case.
            if self.options['face_type'] != 'wf':
                self.options['face_style_power'] = np.clip(io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0)
                self.options['bg_style_power'] = np.clip(io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn to transfer background around face. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0)

            self.options['ct_mode'] = io.input_str("Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
            self.options['clipgrad'] = io.input_bool("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")

            self.options['pretrain'] = io.input_bool("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")

        if self.options['pretrain'] and self.get_pretraining_data_path() is None:
            raise Exception("pretraining_data_path is not defined")

        # True only when the loaded options had pretraining on and it was just
        # turned off in this session.
        self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
예제 #22
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(
            devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        resolution = self.resolution = 96
        self.face_type = FaceType.FULL
        ae_dims = 256
        e_dims = 64
        d_dims = 64
        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        models_opt_on_gpu = len(devices) >= 1 and all(
            [dev.total_mem_gb >= 4 for dev in devices])
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        self.model_filename_list = []

        kernel_initializer = tf.initializers.glorot_uniform()

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(in_ch,
                                       out_ch * 4,
                                       kernel_size=kernel_size,
                                       padding='SAME',
                                       kernel_initializer=kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                x = tf.nn.leaky_relu(x, 0.1)
                x = nn.depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3):
                self.conv1 = nn.Conv2D(ch,
                                       ch,
                                       kernel_size=kernel_size,
                                       padding='SAME',
                                       kernel_initializer=kernel_initializer)
                self.conv2 = nn.Conv2D(ch,
                                       ch,
                                       kernel_size=kernel_size,
                                       padding='SAME',
                                       kernel_initializer=kernel_initializer)

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.2)
                x = self.conv2(x)
                x = tf.nn.leaky_relu(inp + x, 0.2)
                return x

        class Encoder(nn.ModelBase):
            def on_build(self, in_ch, e_ch):

                self.down11 = nn.Conv2D(in_ch,
                                        e_ch,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down12 = nn.Conv2D(e_ch,
                                        e_ch,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)

                self.down21 = nn.Conv2D(e_ch,
                                        e_ch * 2,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down22 = nn.Conv2D(e_ch * 2,
                                        e_ch * 2,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)

                self.down31 = nn.Conv2D(e_ch * 2,
                                        e_ch * 4,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down32 = nn.Conv2D(e_ch * 4,
                                        e_ch * 4,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down33 = nn.Conv2D(e_ch * 4,
                                        e_ch * 4,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)

                self.down41 = nn.Conv2D(e_ch * 4,
                                        e_ch * 8,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down42 = nn.Conv2D(e_ch * 8,
                                        e_ch * 8,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down43 = nn.Conv2D(e_ch * 8,
                                        e_ch * 8,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)

                self.down51 = nn.Conv2D(e_ch * 8,
                                        e_ch * 8,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down52 = nn.Conv2D(e_ch * 8,
                                        e_ch * 8,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)
                self.down53 = nn.Conv2D(e_ch * 8,
                                        e_ch * 8,
                                        kernel_size=3,
                                        strides=1,
                                        padding='SAME',
                                        kernel_initializer=kernel_initializer)

            def forward(self, inp):
                x = inp

                x = self.down11(x)
                x = self.down12(x)
                x = nn.max_pool(x)

                x = self.down21(x)
                x = self.down22(x)
                x = nn.max_pool(x)

                x = self.down31(x)
                x = self.down32(x)
                x = self.down33(x)
                x = nn.max_pool(x)

                x = self.down41(x)
                x = self.down42(x)
                x = self.down43(x)
                x = nn.max_pool(x)

                x = self.down51(x)
                x = self.down52(x)
                x = self.down53(x)
                x = nn.max_pool(x)

                x = nn.flatten(x)
                return x

        class Downscale(nn.ModelBase):
            def __init__(self,
                         in_ch,
                         out_ch,
                         kernel_size=5,
                         dilations=1,
                         subpixel=True,
                         use_activator=True,
                         *kwargs):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                self.dilations = dilations
                self.subpixel = subpixel
                self.use_activator = use_activator
                super().__init__(*kwargs)

            def on_build(self, *args, **kwargs):
                self.conv1 = nn.Conv2D(self.in_ch,
                                       self.out_ch //
                                       (4 if self.subpixel else 1),
                                       kernel_size=self.kernel_size,
                                       strides=1 if self.subpixel else 2,
                                       padding='SAME',
                                       dilations=self.dilations,
                                       kernel_initializer=kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                if self.subpixel:
                    x = nn.space_to_depth(x, 2)
                if self.use_activator:
                    x = tf.nn.leaky_relu(x, 0.1)
                return x

            def get_out_ch(self):
                return (self.out_ch // 4) * 4

        class DownscaleBlock(nn.ModelBase):
            def on_build(self,
                         in_ch,
                         ch,
                         n_downscales,
                         kernel_size,
                         dilations=1,
                         subpixel=True):
                self.downs = []

                last_ch = in_ch
                for i in range(n_downscales):
                    cur_ch = ch * (min(2**i, 8))
                    self.downs.append(
                        Downscale(last_ch,
                                  cur_ch,
                                  kernel_size=kernel_size,
                                  dilations=dilations,
                                  subpixel=subpixel))
                    last_ch = self.downs[-1].get_out_ch()

            def forward(self, inp):
                x = inp
                for down in self.downs:
                    x = down(x)
                return x

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3):
                self.conv1 = nn.Conv2D(in_ch,
                                       out_ch * 4,
                                       kernel_size=kernel_size,
                                       padding='SAME',
                                       kernel_initializer=kernel_initializer)

            def forward(self, x):
                x = self.conv1(x)
                x = tf.nn.leaky_relu(x, 0.1)
                x = nn.depth_to_space(x, 2)
                return x

        class Encoder(nn.ModelBase):
            def on_build(self, in_ch, e_ch):
                self.down1 = DownscaleBlock(in_ch,
                                            e_ch,
                                            n_downscales=4,
                                            kernel_size=5,
                                            dilations=1,
                                            subpixel=False)

            def forward(self, inp):
                x = nn.flatten(self.down1(inp))
                return x

        class Branch(nn.ModelBase):
            def on_build(self, in_ch, ae_ch):
                self.dense1 = nn.Dense(in_ch, ae_ch)

            def forward(self, inp):
                x = self.dense1(inp)
                return x

        class Classifier(nn.ModelBase):
            def on_build(self, in_ch, n_classes):
                self.dense1 = nn.Dense(in_ch, 4096)
                self.dense2 = nn.Dense(4096, 4096)
                self.pitch_dense = nn.Dense(4096, n_classes)
                self.yaw_dense = nn.Dense(4096, n_classes)

            def forward(self, inp):
                x = inp
                x = self.dense1(x)
                x = self.dense2(x)
                return self.pitch_dense(x), self.yaw_dense(x)

        lowest_dense_res = resolution // 16

        class Inter(nn.ModelBase):
            def on_build(self, in_ch, ae_out_ch):
                self.ae_out_ch = ae_out_ch

                self.dense2 = nn.Dense(
                    in_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)
                self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

            def forward(self, inp):
                x = inp
                x = self.dense2(x)
                x = nn.reshape_4D(x, lowest_dense_res, lowest_dense_res,
                                  self.ae_out_ch)
                x = self.upscale1(x)
                return x

            def get_out_ch(self):
                return self.ae_out_ch

        class Decoder(nn.ModelBase):
            def on_build(self, in_ch, d_ch, d_mask_ch):

                self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3)
                self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3)
                self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3)

                self.res0 = ResidualBlock(d_ch * 8, kernel_size=3)
                self.res1 = ResidualBlock(d_ch * 4, kernel_size=3)
                self.res2 = ResidualBlock(d_ch * 2, kernel_size=3)

                self.out_conv = nn.Conv2D(
                    d_ch * 2,
                    3,
                    kernel_size=1,
                    padding='SAME',
                    kernel_initializer=kernel_initializer)

            def forward(self, inp):
                z = inp

                x = self.upscale0(z)
                x = self.res0(x)
                x = self.upscale1(x)
                x = self.res1(x)
                x = self.upscale2(x)
                x = self.res2(x)

                return tf.nn.sigmoid(self.out_conv(x))

        n_pyr_degs = self.n_pyr_degs = 3
        n_pyr_classes = self.n_pyr_classes = 180 // self.n_pyr_degs

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.floatx, bgr_shape)
            self.pitches_vector = tf.placeholder(nn.floatx,
                                                 (None, n_pyr_classes))
            self.yaws_vector = tf.placeholder(nn.floatx, (None, n_pyr_classes))

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.compute_output_channels(
                (nn.floatx, bgr_shape))

            self.bT = Branch(in_ch=encoder_out_ch, ae_ch=ae_dims, name='bT')
            self.bP = Branch(in_ch=encoder_out_ch, ae_ch=ae_dims, name='bP')

            self.bTC = Classifier(in_ch=ae_dims,
                                  n_classes=self.n_pyr_classes,
                                  name='bTC')
            self.bPC = Classifier(in_ch=ae_dims,
                                  n_classes=self.n_pyr_classes,
                                  name='bPC')

            self.inter = Inter(in_ch=ae_dims * 2,
                               ae_out_ch=ae_dims * 2,
                               name='inter')

            self.decoder = Decoder(in_ch=ae_dims * 2,
                                   d_ch=d_dims,
                                   d_mask_ch=d_dims,
                                   name='decoder')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.bT, 'bT.npy'],
                                         [self.bTC, 'bTC.npy'],
                                         [self.bP, 'bP.npy'],
                                         [self.bPC, 'bPC.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder, 'decoder.npy']]

            if self.is_training:
                self.all_trainable_weights = self.encoder.get_weights() + \
                                             self.bT.get_weights() +\
                                             self.bTC.get_weights() +\
                                             self.bP.get_weights() +\
                                             self.bPC.get_weights() +\
                                             self.inter.get_weights() +\
                                             self.decoder.get_weights()

                # Initialize optimizers
                self.src_dst_opt = nn.RMSprop(lr=5e-5, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(
                    self.all_trainable_weights,
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt,
                                              'src_dst_opt.npy')]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, 32 // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_list = []
            gpu_pred_dst_list = []

            gpu_A_losses = []
            gpu_B_losses = []
            gpu_C_losses = []
            gpu_D_losses = []
            gpu_A_loss_gvs = []
            gpu_B_loss_gvs = []
            gpu_C_loss_gvs = []
            gpu_D_loss_gvs = []
            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]

                        gpu_pitches_vector = self.pitches_vector[
                            batch_slice, :]
                        gpu_yaws_vector = self.yaws_vector[batch_slice, :]

                    # process model tensors
                    gpu_src_enc_code = self.encoder(gpu_warped_src)
                    gpu_dst_enc_code = self.encoder(gpu_target_dst)

                    gpu_src_bT_code = self.bT(gpu_src_enc_code)
                    gpu_src_bT_code_ng = tf.stop_gradient(gpu_src_bT_code)

                    gpu_src_T_pitch, gpu_src_T_yaw = self.bTC(gpu_src_bT_code)

                    gpu_dst_bT_code = self.bT(gpu_dst_enc_code)

                    gpu_src_bP_code = self.bP(gpu_src_enc_code)

                    gpu_src_P_pitch, gpu_src_P_yaw = self.bPC(gpu_src_bP_code)

                    def crossentropy(target, output):
                        output = tf.nn.softmax(output)
                        output = tf.clip_by_value(output, 1e-7, 1 - 1e-7)
                        return tf.reduce_sum(target * -tf.log(output),
                                             axis=-1,
                                             keepdims=False)

                    def negative_crossentropy(n_classes, output):
                        output = tf.nn.softmax(output)
                        output = tf.clip_by_value(output, 1e-7, 1 - 1e-7)
                        return (1.0 / n_classes) * tf.reduce_sum(
                            tf.log(output), axis=-1, keepdims=False)

                    gpu_src_bT_code_n = gpu_src_bT_code_ng + tf.random.normal(
                        tf.shape(gpu_src_bT_code_ng))
                    gpu_src_bP_code_n = gpu_src_bP_code + tf.random.normal(
                        tf.shape(gpu_src_bP_code))

                    gpu_pred_src = self.decoder(
                        self.inter(
                            tf.concat([gpu_src_bT_code_ng, gpu_src_bP_code],
                                      axis=-1)))
                    gpu_pred_src_n = self.decoder(
                        self.inter(
                            tf.concat([gpu_src_bT_code_n, gpu_src_bP_code_n],
                                      axis=-1)))
                    gpu_pred_dst = self.decoder(
                        self.inter(
                            tf.concat([gpu_dst_bT_code, gpu_src_bP_code],
                                      axis=-1)))

                    gpu_A_loss  = 1.0*crossentropy(gpu_pitches_vector, gpu_src_T_pitch ) + \
                                  1.0*crossentropy(gpu_yaws_vector,    gpu_src_T_yaw )


                    gpu_B_loss = 0.1*crossentropy(gpu_pitches_vector, gpu_src_P_pitch ) + \
                                 0.1*crossentropy(gpu_yaws_vector, gpu_src_P_yaw )

                    gpu_C_loss = 0.1*negative_crossentropy( n_pyr_classes, gpu_src_P_pitch ) + \
                                 0.1*negative_crossentropy( n_pyr_classes, gpu_src_P_yaw )


                    gpu_D_loss = 0.0000001*(\
                                    0.5*tf.reduce_sum(tf.square(gpu_target_src-gpu_pred_src), axis=[1,2,3]) + \
                                    0.5*tf.reduce_sum(tf.square(gpu_target_src-gpu_pred_src_n), axis=[1,2,3]) )

                    gpu_pred_src_list.append(gpu_pred_src)
                    gpu_pred_dst_list.append(gpu_pred_dst)

                    gpu_A_losses += [gpu_A_loss]
                    gpu_B_losses += [gpu_B_loss]
                    gpu_C_losses += [gpu_C_loss]
                    gpu_D_losses += [gpu_D_loss]

                    A_weights = self.encoder.get_weights(
                    ) + self.bT.get_weights() + self.bTC.get_weights()
                    B_weights = self.bPC.get_weights()
                    C_weights = self.encoder.get_weights(
                    ) + self.bP.get_weights()
                    D_weights = self.inter.get_weights(
                    ) + self.decoder.get_weights()

                    gpu_A_loss_gvs += [nn.gradients(gpu_A_loss, A_weights)]
                    gpu_B_loss_gvs += [nn.gradients(gpu_B_loss, B_weights)]
                    gpu_C_loss_gvs += [nn.gradients(gpu_C_loss, C_weights)]
                    gpu_D_loss_gvs += [nn.gradients(gpu_D_loss, D_weights)]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred_src = nn.concat(gpu_pred_src_list, 0)
                pred_dst = nn.concat(gpu_pred_dst_list, 0)
                A_loss = nn.average_tensor_list(gpu_A_losses)
                B_loss = nn.average_tensor_list(gpu_B_losses)
                C_loss = nn.average_tensor_list(gpu_C_losses)
                D_loss = nn.average_tensor_list(gpu_D_losses)

                A_loss_gv = nn.average_gv_list(gpu_A_loss_gvs)
                B_loss_gv = nn.average_gv_list(gpu_B_loss_gvs)
                C_loss_gv = nn.average_gv_list(gpu_C_loss_gvs)
                D_loss_gv = nn.average_gv_list(gpu_D_loss_gvs)
                A_loss_gv_op = self.src_dst_opt.get_update_op(A_loss_gv)
                B_loss_gv_op = self.src_dst_opt.get_update_op(B_loss_gv)
                C_loss_gv_op = self.src_dst_opt.get_update_op(C_loss_gv)
                D_loss_gv_op = self.src_dst_opt.get_update_op(D_loss_gv)

            # Initializing training and view functions
            def A_train(warped_src, target_src, pitches_vector, yaws_vector):
                l, _ = nn.tf_sess.run(
                    [A_loss, A_loss_gv_op],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.pitches_vector: pitches_vector,
                        self.yaws_vector: yaws_vector
                    })
                return np.mean(l)

            self.A_train = A_train

            def B_train(warped_src, target_src, pitches_vector, yaws_vector):
                l, _ = nn.tf_sess.run(
                    [B_loss, B_loss_gv_op],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.pitches_vector: pitches_vector,
                        self.yaws_vector: yaws_vector
                    })
                return np.mean(l)

            self.B_train = B_train

            def C_train(warped_src, target_src, pitches_vector, yaws_vector):
                l, _ = nn.tf_sess.run(
                    [C_loss, C_loss_gv_op],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.pitches_vector: pitches_vector,
                        self.yaws_vector: yaws_vector
                    })
                return np.mean(l)

            self.C_train = C_train

            def D_train(warped_src, target_src, pitches_vector, yaws_vector):
                l, _ = nn.tf_sess.run(
                    [D_loss, D_loss_gv_op],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.pitches_vector: pitches_vector,
                        self.yaws_vector: yaws_vector
                    })
                return np.mean(l)

            self.D_train = D_train

            def AE_view(warped_src):
                return nn.tf_sess.run([pred_src],
                                      feed_dict={self.warped_src: warped_src})

            self.AE_view = AE_view

            def AE_view2(warped_src, target_dst):
                return nn.tf_sess.run([pred_dst],
                                      feed_dict={
                                          self.warped_src: warped_src,
                                          self.target_dst: target_dst
                                      })

            self.AE_view2 = AE_view2
        else:
            # Initializing merge function
            with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge(warped_dst):

                return nn.tf_sess.run(
                    [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                    feed_dict={self.warped_dst: warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(
                self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter:
                    do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(
                    self.get_strpath_storage_for_file(filename))

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
            )
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
            )

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            self.set_training_data_generators([
                SampleGeneratorFace(
                    training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.PITCH_YAW_ROLL_SIGMOID,
                            'resolution': resolution
                        },
                    ],
                    generators_count=src_generators_count),
                SampleGeneratorFace(
                    training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.PITCH_YAW_ROLL_SIGMOID,
                            'resolution': resolution
                        },
                    ],
                    generators_count=dst_generators_count)
            ])

            self.last_samples = None
예제 #23
0
    def on_initialize(self):
        """Build the XSeg model and, in training mode, its train/view ops and data generators.

        Side effects:
          * calls ``nn.initialize`` with "NCHW" when exporting, or when devices are
            configured and not in debug mode; "NHWC" otherwise;
          * sets ``self.model_data_format``, ``self.resolution`` (fixed at 256),
            ``self.face_type``, ``self.model`` and ``self.pretrain``; resets the
            iteration counter to 0 when pretraining was just disabled;
          * when ``self.is_training``:
            - rounds the batch size down to a multiple of ``max(1, len(devices))``,
              with at least one sample per device;
            - defines ``self.train(input_np, target_np)``, which runs one optimizer
              step and returns the evaluated ``loss`` tensor;
            - defines ``self.view(input_np)``, which returns ``[pred]`` evaluated;
            - registers the training sample generators.
        """
        device_config = nn.getCurrentDeviceConfig()
        # NCHW when exporting or when any device is configured (outside debug mode).
        self.model_data_format = "NCHW" if self.is_exporting or (len(
            device_config.devices) != 0 and not self.is_debug()) else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        # Re-queried after nn.initialize(), as the original did; kept in case
        # initialization replaces the active device config.
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256

        self.face_type = {
            'h': FaceType.HALF,
            'mf': FaceType.MID_FULL,
            'f': FaceType.FULL,
            'wf': FaceType.WHOLE_FACE,
            'head': FaceType.HEAD
        }[self.options['face_type']]

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name

        # Initializing model classes
        self.model = XSegNet(name='XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001,
                                                  lr_dropout=0.3,
                                                  name='opt'),
                             data_format=nn.data_format)

        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        if not self.is_training:
            return

        # Adjust batch size for multiple GPU
        gpu_count = max(1, len(devices))
        bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
        self.set_batch_size(gpu_count * bs_per_gpu)

        # Compute losses per GPU
        gpu_pred_list = []
        gpu_losses = []
        gpu_loss_gvs = []

        for gpu_id in range(gpu_count):
            with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}'
                           if len(devices) != 0 else '/CPU:0'):
                with tf.device('/CPU:0'):
                    # slice on CPU, otherwise all batch data will be transferred to GPU first
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                    gpu_target_t = self.model.target_t[batch_slice, :, :, :]

                # process model tensors
                gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                    gpu_input_t, pretrain=self.pretrain)
                gpu_pred_list.append(gpu_pred_t)

                if self.pretrain:
                    # Structural loss at two DSSIM filter sizes ...
                    gpu_loss = tf.reduce_mean(
                        5 * nn.dssim(gpu_target_t,
                                     gpu_pred_t,
                                     max_val=1.0,
                                     filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_loss += tf.reduce_mean(
                        5 * nn.dssim(gpu_target_t,
                                     gpu_pred_t,
                                     max_val=1.0,
                                     filter_size=int(resolution / 23.2)),
                        axis=[1])
                    # ... plus a squared-error pixel loss.
                    gpu_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_t - gpu_pred_t),
                        axis=[1, 2, 3])
                else:
                    # Segmentation mode: sigmoid cross-entropy on the raw logits.
                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])

                gpu_losses.append(gpu_loss)
                gpu_loss_gvs.append(
                    nn.gradients(gpu_loss, self.model.get_weights()))

        # Average losses and gradients, and create optimizer update ops.
        # NOTE: this used to be pinned to '/CPU:0' as a temporary fix for a training
        # freeze seen from 2.4.0 onward (2.3.1 was ok); it now follows models_opt_device.
        with tf.device(models_opt_device):
            pred = tf.concat(gpu_pred_list, 0)
            loss = tf.concat(gpu_losses, 0)
            loss_gv_op = self.model.opt.get_update_op(
                nn.average_gv_list(gpu_loss_gvs))

        # Initializing training and view functions.
        # The train step is the same for pretrain and segmentation; only the loss
        # graph built above differs, so a single definition serves both modes.
        def train(input_np, target_np):
            l, _ = nn.tf_sess.run(
                [loss, loss_gv_op],
                feed_dict={
                    self.model.input_t: input_np,
                    self.model.target_t: target_np
                })
            return l

        self.train = train

        def view(input_np):
            return nn.tf_sess.run([pred],
                                  feed_dict={self.model.input_t: input_np})

        self.view = view

        # initializing sample generators
        cpu_count = min(multiprocessing.cpu_count(), 8)
        generators_count = cpu_count // 2

        if self.pretrain:
            pretrain_gen = SampleGeneratorFace(
                self.get_pretraining_data_path(),
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=True),
                output_sample_types=[
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': True,
                        'transform': True,
                        'channel_type': SampleProcessor.ChannelType.G,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                uniform_yaw_distribution=False,
                generators_count=cpu_count)
            self.set_training_data_generators([pretrain_gen])
            return

        srcdst_generator = SampleGeneratorFaceXSeg(
            [self.training_data_src_path, self.training_data_dst_path],
            debug=self.is_debug(),
            batch_size=self.get_batch_size(),
            resolution=resolution,
            face_type=self.face_type,
            generators_count=generators_count,
            data_format=nn.data_format)

        def make_plain_face_generator(data_path):
            """Un-warped, un-transformed, un-flipped BGR face samples from `data_path`."""
            return SampleGeneratorFace(
                data_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type':
                        SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=generators_count,
                raise_on_no_data=False)

        src_generator = make_plain_face_generator(self.training_data_src_path)
        dst_generator = make_plain_face_generator(self.training_data_dst_path)

        self.set_training_data_generators(
            [srcdst_generator, src_generator, dst_generator])
예제 #24
0
    def on_initialize(self):
        nn.initialize()
        tf = nn.tf

        class EncBlock(nn.ModelBase):
            """Encoder stage: two LeakyReLU(0.2) convolutions, followed by nn.max_pool above level 0.

            Level 0 uses an unpadded ('VALID') 4x4 second convolution instead of a
            padded 3x3 one, and skips the pooling step.
            """

            def on_build(self, in_ch, out_ch, level):
                is_bottom = level == 0
                self.zero_level = is_bottom
                self.conv1 = nn.Conv2D(in_ch,
                                       out_ch,
                                       kernel_size=3,
                                       padding='SAME')
                if is_bottom:
                    second_kernel, second_padding = 4, 'VALID'
                else:
                    second_kernel, second_padding = 3, 'SAME'
                self.conv2 = nn.Conv2D(out_ch,
                                       out_ch,
                                       kernel_size=second_kernel,
                                       padding=second_padding)

            def forward(self, x):
                h = x
                for conv in (self.conv1, self.conv2):
                    h = tf.nn.leaky_relu(conv(h), 0.2)
                return h if self.zero_level else nn.max_pool(h)

        class DecBlock(nn.ModelBase):
            """Decoder stage: nn.upsample2d above level 0, then two LeakyReLU(0.2) convolutions.

            At level 0 the first convolution uses a 4x4 kernel with padding=3
            instead of a 3x3 'SAME' one.
            """

            def on_build(self, in_ch, out_ch, level):
                self.zero_level = level == 0
                if self.zero_level:
                    first_kernel, first_padding = 4, 3
                else:
                    first_kernel, first_padding = 3, 'SAME'
                self.conv1 = nn.Conv2D(in_ch,
                                       out_ch,
                                       kernel_size=first_kernel,
                                       padding=first_padding)
                self.conv2 = nn.Conv2D(out_ch,
                                       out_ch,
                                       kernel_size=3,
                                       padding='SAME')

            def forward(self, x):
                h = x if self.zero_level else nn.upsample2d(x)
                for conv in (self.conv1, self.conv2):
                    h = tf.nn.leaky_relu(conv(h), 0.2)
                return h

        class FromRGB(nn.ModelBase):
            """1x1 convolution from 3 input channels to `out_ch` features, with LeakyReLU(0.2)."""

            def on_build(self, out_ch):
                self.conv1 = nn.Conv2D(3, out_ch, kernel_size=1, padding='SAME')

            def forward(self, x):
                features = self.conv1(x)
                return tf.nn.leaky_relu(features, 0.2)

        class ToRGB(nn.ModelBase):
            """Two 1x1 sigmoid heads: a 3-channel output and a 1-channel output, returned as a pair."""

            def on_build(self, in_ch):
                self.conv = nn.Conv2D(in_ch, 3, kernel_size=1, padding='SAME')
                self.convm = nn.Conv2D(in_ch, 1, kernel_size=1, padding='SAME')

            def forward(self, x):
                rgb = tf.nn.sigmoid(self.conv(x))
                mask = tf.nn.sigmoid(self.convm(x))
                return rgb, mask

        class Encoder(nn.ModelBase):
            """Progressive encoder with one FromRGB + EncBlock pair per level.

            Levels run from `levels` down to 0. Each level doubles the channel width
            (capped at 512), and the final width is stored as `self.max_ch`. The
            output is flattened, passed through DenseNorm and reshaped to
            (1, 1, max_ch) via nn.reshape_4D.
            """

            def on_build(self, e_ch, levels):
                self.enc_blocks = {}
                self.from_rgbs = {}
                self.dense_norm = nn.DenseNorm()

                in_ch = e_ch
                out_ch = in_ch
                for level in range(levels, -1, -1):
                    self.max_ch = out_ch = np.clip(out_ch * 2, 0, 512)

                    self.enc_blocks[level] = EncBlock(in_ch, out_ch, level)
                    self.from_rgbs[level] = FromRGB(in_ch)

                    in_ch = out_ch

            def forward(self, inp, stage):
                """Encode `inp` entering at `stage` and descending to level 0."""
                x = inp
                descending = range(stage, -1, -1)
                # The membership test is loop-invariant; on_build creates every level
                # 0..levels, so a known `stage` means all lower levels exist too.
                if stage in self.enc_blocks:
                    # Only the entry level converts the RGB input into features.
                    x = self.from_rgbs[stage](x)
                    for level in descending:
                        x = self.enc_blocks[level](x)

                x = nn.flatten(x)
                x = self.dense_norm(x)
                x = nn.reshape_4D(x, 1, 1, self.max_ch)

                return x

            def get_stage_weights(self, stage):
                """Return a new flat list of the weights used by `forward` at `stage`.

                Returns [] when `stage` has no blocks.
                """
                # Kept from the original: presumably needed so sub-model weights are
                # built before they are read (Decoder.get_stage_weights does the same).
                self.get_weights()
                descending = range(stage, -1, -1)
                per_block = []
                if stage in self.enc_blocks:
                    per_block.append(self.from_rgbs[stage].get_weights())
                    for level in descending:
                        per_block.append(self.enc_blocks[level].get_weights())

                # Single-pass flatten: sum(lists, start) re-copies the accumulator on
                # every addition. The result is always a fresh list, never one of the
                # sub-models' own lists.
                return [w for block_weights in per_block for w in block_weights]

        class Decoder(nn.ModelBase):
            """Progressive decoder owning the DecBlock/ToRGB pairs for `levels_range`.

            Channel widths are derived per level from `d_ch`, doubling as the level
            decreases (capped at 512). Only levels levels_range[0]..levels_range[1]
            (inclusive) get blocks.
            """

            def on_build(self, d_ch, total_levels, levels_range):

                self.dec_blocks = {}
                self.to_rgbs = {}

                level_ch = {}
                ch = d_ch
                for level in range(total_levels, -2, -1):
                    level_ch[level] = ch
                    ch = np.clip(ch * 2, 0, 512)

                out_ch = level_ch[levels_range[1]]
                for level in range(levels_range[1], levels_range[0] - 1, -1):
                    in_ch = level_ch[level - 1]

                    self.dec_blocks[level] = DecBlock(in_ch, out_ch, level)
                    self.to_rgbs[level] = ToRGB(out_ch)

                    out_ch = in_ch

            def forward(self, inp, stage):
                """Run the owned blocks for levels 0..stage.

                If this decoder owns `stage`, its ToRGB pair is applied last and the
                (rgb, mask) tuple is returned; otherwise the feature tensor is returned.
                """
                x = inp

                for level in range(stage + 1):
                    if level in self.dec_blocks:
                        x = self.dec_blocks[level](x)
                        if level == stage:
                            x = self.to_rgbs[level](x)
                return x

            def get_stage_weights(self, stage):
                """Return a new flat list of the weights used by `forward` at `stage`.

                Returns [] when no owned level is <= `stage`.
                """
                # Call internal get_weights in order to initialize inner logic
                self.get_weights()

                per_block = []
                for level in range(stage + 1):
                    if level in self.dec_blocks:
                        per_block.append(self.dec_blocks[level].get_weights())
                        if level == stage:
                            per_block.append(self.to_rgbs[level].get_weights())

                # Single-pass flatten: sum(lists, start) re-copies the accumulator on
                # every addition. The result is always a fresh list, never one of the
                # sub-models' own lists.
                return [w for block_weights in per_block for w in block_weights]

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.stage = stage = self.options['stage']
        self.start_stage_iter = self.options.get('start_stage_iter', 0)
        self.target_stage_iter = self.options.get('target_stage_iter', 0)

        stage_resolutions = [2**(i + 2) for i in range(self.stage_max + 1)]
        stage_resolution = stage_resolutions[stage]

        ed_dims = 16

        self.pretrain = False
        self.pretrain_just_disabled = False

        masked_training = True

        models_opt_on_gpu = len(devices) == 1 and devices[0].total_mem_gb >= 4
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_nc = 3
        output_nc = 3
        bgr_shape = (stage_resolution, stage_resolution, output_nc)
        mask_shape = (stage_resolution, stage_resolution, 1)

        self.model_filename_list = []

        with tf.device('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder(tf.float32, (None, ) + bgr_shape)
            self.warped_dst = tf.placeholder(tf.float32, (None, ) + bgr_shape)

            self.target_src = tf.placeholder(tf.float32, (None, ) + bgr_shape)
            self.target_dst = tf.placeholder(tf.float32, (None, ) + bgr_shape)

            self.target_srcm = tf.placeholder(tf.float32,
                                              (None, ) + mask_shape)
            self.target_dstm = tf.placeholder(tf.float32,
                                              (None, ) + mask_shape)

        # Initializing model classes
        with tf.device(models_opt_device):
            self.encoder = Encoder(e_ch=ed_dims,
                                   levels=self.stage_max,
                                   name='encoder')

            self.inter = Decoder(d_ch=ed_dims,
                                 total_levels=self.stage_max,
                                 levels_range=[0, 2],
                                 name='inter')
            self.decoder_src = Decoder(d_ch=ed_dims,
                                       total_levels=self.stage_max,
                                       levels_range=[3, self.stage_max],
                                       name='decoder_src')
            self.decoder_dst = Decoder(d_ch=ed_dims,
                                       total_levels=self.stage_max,
                                       levels_range=[3, self.stage_max],
                                       name='decoder_dst')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                self.src_dst_all_weights = self.encoder.get_weights(
                ) + self.inter.get_weights() + self.decoder_src.get_weights(
                ) + self.decoder_dst.get_weights()
                self.src_dst_trainable_weights = self.encoder.get_stage_weights(stage) + self.inter.get_stage_weights(stage) \
                               + self.decoder_src.get_stage_weights(stage) \
                               + self.decoder_dst.get_stage_weights(stage)

                # Initialize optimizers
                self.src_dst_opt = nn.RMSprop(lr=2e-4,
                                              lr_dropout=0.3,
                                              name='src_dst_opt')
                self.src_dst_opt.initialize_variables(
                    self.src_dst_all_weights,
                    vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt,
                                              'src_dst_opt.npy')]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu,
                                        (gpu_id + 1) * bs_per_gpu)
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm = self.target_srcm[
                            batch_slice, :, :, :]
                        gpu_target_dstm = self.target_dstm[
                            batch_slice, :, :, :]

                    # process model tensors

                    gpu_src_code = self.inter(
                        self.encoder(gpu_warped_src, stage), stage)
                    gpu_dst_code = self.inter(
                        self.encoder(gpu_warped_dst, stage), stage)

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(
                        gpu_src_code, stage)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(
                        gpu_dst_code, stage)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                        gpu_dst_code, stage)

                    import code
                    code.interact(local=dict(globals(), **locals()))

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(
                        gpu_target_srcm, max(1, resolution // 32))
                    gpu_target_dstm_blur = nn.gaussian_blur(
                        gpu_target_dstm, max(1, resolution // 32))

                    gpu_target_dst_masked = gpu_target_dst * gpu_target_dstm_blur
                    gpu_target_dst_anti_masked = gpu_target_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_target_srcmasked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_masked = gpu_pred_src_dst * gpu_target_dstm_blur
                    gpu_psd_target_dst_anti_masked = gpu_pred_src_dst * (
                        1.0 - gpu_target_dstm_blur)

                    gpu_src_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_srcmasked_opt,
                                      gpu_pred_src_src_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_src_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_srcmasked_opt -
                                       gpu_pred_src_src_masked_opt),
                        axis=[1, 2, 3])
                    gpu_src_loss += tf.reduce_mean(
                        tf.square(gpu_target_srcm - gpu_pred_src_srcm),
                        axis=[1, 2, 3])

                    gpu_dst_loss = tf.reduce_mean(
                        10 * nn.dssim(gpu_target_dst_masked_opt,
                                      gpu_pred_dst_dst_masked_opt,
                                      max_val=1.0,
                                      filter_size=int(resolution / 11.6)),
                        axis=[1])
                    gpu_dst_loss += tf.reduce_mean(
                        10 * tf.square(gpu_target_dst_masked_opt -
                                       gpu_pred_dst_dst_masked_opt),
                        axis=[1, 2, 3])
                    gpu_dst_loss += tf.reduce_mean(
                        tf.square(gpu_target_dstm - gpu_pred_dst_dstm),
                        axis=[1, 2, 3])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [
                        nn.gradients(gpu_src_dst_loss,
                                     self.src_dst_trainable_weights)
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                if gpu_count == 1:
                    pred_src_src = gpu_pred_src_src_list[0]
                    pred_dst_dst = gpu_pred_dst_dst_list[0]
                    pred_src_dst = gpu_pred_src_dst_list[0]
                    pred_src_srcm = gpu_pred_src_srcm_list[0]
                    pred_dst_dstm = gpu_pred_dst_dstm_list[0]
                    pred_src_dstm = gpu_pred_src_dstm_list[0]

                    src_loss = gpu_src_losses[0]
                    dst_loss = gpu_dst_losses[0]
                    src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
                else:
                    pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
                    pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
                    pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
                    pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
                    pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
                    pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)

                    src_loss = nn.average_tensor_list(gpu_src_losses)
                    dst_loss = nn.average_tensor_list(gpu_dst_losses)
                    src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)

                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(
                    src_dst_loss_gv)

            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, \
                              warped_dst, target_dst, target_dstm):
                s, d, _ = nn.tf_sess.run(
                    [src_loss, dst_loss, src_dst_loss_gv_op],
                    feed_dict={
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.target_srcm: target_srcm,
                        self.warped_dst: warped_dst,
                        self.target_dst: target_dst,
                        self.target_dstm: target_dstm,
                    })
                s = np.mean(s)
                d = np.mean(d)
                return s, d

            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                return nn.tf_sess.run([
                    pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst,
                    pred_src_dstm
                ],
                                      feed_dict={
                                          self.warped_src: warped_src,
                                          self.warped_dst: warped_dst
                                      })

            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device(f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(
                    gpu_dst_code, stage=stage)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code,
                                                        stage=stage)

            def AE_merge(warped_dst):
                return nn.tf_sess.run(
                    [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                    feed_dict={self.warped_dst: warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(
                self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()

            if self.pretrain_just_disabled:
                if model == self.inter:
                    do_init = True

            if not do_init:
                do_init = not model.load_weights(
                    self.get_strpath_storage_for_file(filename))

            if do_init:
                model.init_weights()

        # initializing sample generators

        if self.is_training:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL

            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path(
            )
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path(
            )

            cpu_count = multiprocessing.cpu_count()

            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count - src_generators_count

            self.set_training_data_generators([
                SampleGeneratorFace(
                    training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[{
                        'types':
                        (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                        'resolution':
                        resolution,
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                        'resolution':
                        resolution,
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type,
                                  t.MODE_FACE_MASK_ALL_HULL),
                        'resolution':
                        resolution
                    }],
                    generators_count=src_generators_count),
                SampleGeneratorFace(
                    training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True if self.pretrain else False),
                    output_sample_types=[{
                        'types':
                        (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR),
                        'resolution':
                        resolution
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR),
                        'resolution':
                        resolution
                    }, {
                        'types': (t.IMG_TRANSFORMED, face_type,
                                  t.MODE_FACE_MASK_ALL_HULL),
                        'resolution':
                        resolution
                    }],
                    generators_count=dst_generators_count)
            ])

            self.last_samples = None
예제 #25
0
def _merger_collect_alignments(aligned_path):
    """Group the aligned faces in *aligned_path* by the source frame they came from.

    Faces are read from a packed faceset when ``samplelib.PackedFaceset.load``
    returns one, otherwise from the image files in the directory. Files that
    are not DFL images are reported and skipped; faces without a recorded
    source filename are skipped silently.

    Returns:
        tuple: ``(alignments, multiple_faces_detected)``. ``alignments`` maps a
        source filename stem to the list of source landmarks of every aligned
        face referring to it. ``multiple_faces_detected`` is True when any stem
        has more than one face; the offending files are logged in that case.
    """
    packed_samples = None
    try:
        packed_samples = samplelib.PackedFaceset.load(aligned_path)
    except Exception:
        # Best effort: on failure fall back to the loose image files below.
        io.log_err(
            f"Error occured while loading samplelib.PackedFaceset.load {str(aligned_path)}, {traceback.format_exc()}"
        )

    if packed_samples is not None:
        io.log_info("Using packed faceset.")

        def generator():
            for sample in io.progress_bar_generator(
                    packed_samples, "Collecting alignments"):
                filepath = Path(sample.filename)
                # Bind `sample` as a default argument so the loader can never
                # observe a later loop value (late-binding closure).
                yield filepath, DFLIMG.load(
                    filepath,
                    loader_func=lambda x, sample=sample: sample.read_raw_file())
    else:

        def generator():
            for filepath in io.progress_bar_generator(
                    pathex.get_image_paths(aligned_path),
                    "Collecting alignments"):
                filepath = Path(filepath)
                yield filepath, DFLIMG.load(filepath)

    # source stem -> [(source_landmarks, aligned_filepath, source_filepath), ...]
    grouped = {}
    for filepath, dflimg in generator():
        if dflimg is None or not dflimg.has_data():
            io.log_err(f"{filepath.name} is not a dfl image file")
            continue

        source_filename = dflimg.get_source_filename()
        if source_filename is None:
            continue

        source_filepath = Path(source_filename)
        grouped.setdefault(source_filepath.stem, []).append(
            (dflimg.get_source_landmarks(), filepath, source_filepath))

    multiple_faces_detected = any(
        len(entries) > 1 for entries in grouped.values())

    if multiple_faces_detected:
        io.log_info("")
        io.log_info(
            "Warning: multiple faces detected. Only one alignment file should refer one source file."
        )
        io.log_info("")

    alignments = {}
    for stem, entries in grouped.items():
        if len(entries) > 1:
            for _, filepath, source_filepath in entries:
                io.log_info(
                    f"alignment {filepath.name} refers to {source_filepath.name} "
                )
            io.log_info("")
        # Only the landmarks are needed downstream.
        alignments[stem] = [landmarks for landmarks, _, _ in entries]

    if multiple_faces_detected:
        io.log_info(
            "It is strongly recommended to process the faces separatelly.")
        io.log_info(
            "Use 'recover original filename' to determine the exact duplicates."
        )
        io.log_info("")

    return alignments, multiple_faces_detected


def _merger_compute_motion_vectors(frames):
    """Estimate per-frame face motion and store it on each ``frame_info``.

    For frame *i*, the first face of frames *i-1* and *i+1* (clamped to the
    sequence ends) is mapped through a 256px ``FaceType.FULL`` alignment and
    the displacement of the local point (127, 127) between them is measured.
    ``motion_power`` is set to its Euclidean length and ``motion_deg`` to
    ``-atan2(dy, dx)`` in degrees. Frames where any of the three neighbours
    has no landmarks are left unchanged.
    """
    s = 256
    local_pts = [(s // 2 - 1, s // 2 - 1), (s // 2 - 1, 0)]  # center+up
    last_idx = len(frames) - 1
    for i in io.progress_bar_generator(range(len(frames)),
                                       "Computing motion vectors"):
        fi_prev = frames[max(0, i - 1)].frame_info
        fi = frames[i].frame_info
        fi_next = frames[min(i + 1, last_idx)].frame_info
        # Truthiness also covers a None landmarks_list, which len() rejects.
        if not (fi_prev.landmarks_list and fi.landmarks_list
                and fi_next.landmarks_list):
            continue

        mat_prev = LandmarksProcessor.get_transform_mat(
            fi_prev.landmarks_list[0], s, face_type=FaceType.FULL)
        mat_next = LandmarksProcessor.get_transform_mat(
            fi_next.landmarks_list[0], s, face_type=FaceType.FULL)

        pts_prev = LandmarksProcessor.transform_points(
            local_pts, mat_prev, True)
        pts_next = LandmarksProcessor.transform_points(
            local_pts, mat_next, True)

        motion_vector = pts_next[0] - pts_prev[0]
        fi.motion_power = npla.norm(motion_vector)

        # Normalize to a unit direction; a zero displacement has no direction.
        if fi.motion_power != 0:
            motion_vector = motion_vector / fi.motion_power
        else:
            motion_vector = np.array([0, 0], dtype=np.float32)

        fi.motion_deg = -math.atan2(motion_vector[1],
                                    motion_vector[0]) * 180 / math.pi


def main(model_class_name=None,
         saved_models_path=None,
         training_data_src_path=None,
         force_model_name=None,
         input_path=None,
         output_path=None,
         output_mask_path=None,
         aligned_path=None,
         force_gpu_idxs=None,
         cpu_only=None):
    """Merge model predictions back onto the frames found in *input_path*.

    Loads the model, collects face alignments from *aligned_path*, estimates
    per-frame motion (unless some source frame has several faces), then runs
    ``InteractiveMergerSubprocessor`` writing to *output_path* and
    *output_mask_path*.

    Args:
        model_class_name: name passed to ``models.import_model``.
        saved_models_path: ``Path`` of the model directory; also used as the
            XSeg weights root.
        training_data_src_path: not used by this function.
        force_model_name: not used by this function.
        input_path: ``Path`` of the directory with the frames to merge.
        output_path: ``Path`` for merged frames; created if missing.
        output_mask_path: ``Path`` for output masks; created if missing.
        aligned_path: ``Path`` of the directory with the aligned faces.
        force_gpu_idxs: forwarded to the model constructor.
        cpu_only: forwarded to the model constructor.

    Missing directories and an unsupported merger config type are reported via
    ``io.log_err``. Other ``Exception``s raised during the merge are printed as
    a traceback instead of being propagated. The model is always finalized
    once it has been created.
    """
    io.log_info("Running merger.\r\n")

    try:
        if not input_path.exists():
            io.log_err('Input directory not found. Please ensure it exists.')
            return

        if not output_path.exists():
            output_path.mkdir(parents=True, exist_ok=True)

        if not output_mask_path.exists():
            output_mask_path.mkdir(parents=True, exist_ok=True)

        if not saved_models_path.exists():
            io.log_err('Model directory not found. Please ensure it exists.')
            return

        # Initialize model
        import models
        model = models.import_model(model_class_name)(
            is_training=False,
            saved_models_path=saved_models_path,
            force_gpu_idxs=force_gpu_idxs,
            cpu_only=cpu_only)

        # Release the model on every exit path, including the early returns
        # below; errors from finalize() still reach the outer handler.
        try:
            predictor_func, predictor_input_shape, cfg = model.get_MergerConfig()

            # Preparing MP functions
            predictor_func = MPFunc(predictor_func)

            run_on_cpu = len(nn.getCurrentDeviceConfig().devices) == 0
            xseg_256_extract_func = MPClassFuncOnDemand(
                XSegNet,
                'extract',
                name='XSeg',
                resolution=256,
                weights_file_root=saved_models_path,
                place_model_on_cpu=True,
                run_on_cpu=run_on_cpu)

            face_enhancer_func = MPClassFuncOnDemand(FaceEnhancer,
                                                     'enhance',
                                                     place_model_on_cpu=True,
                                                     run_on_cpu=run_on_cpu)

            is_interactive = io.input_bool(
                "Use interactive merger?", True) if not io.is_colab() else False

            subprocess_count = multiprocessing.cpu_count()

            input_path_image_paths = pathex.get_image_paths(input_path)

            # Only masked merging is implemented here; previously any other
            # type crashed with a NameError on `frames`.
            if cfg.type != MergerConfig.TYPE_MASKED:
                io.log_err(f"Unsupported merger config type: {cfg.type}")
                return

            if not aligned_path.exists():
                io.log_err(
                    'Aligned directory not found. Please ensure it exists.')
                return

            alignments, multiple_faces_detected = _merger_collect_alignments(
                aligned_path)

            frames = [
                InteractiveMergerSubprocessor.Frame(frame_info=FrameInfo(
                    filepath=Path(p),
                    landmarks_list=alignments.get(Path(p).stem, None)))
                for p in input_path_image_paths
            ]

            if multiple_faces_detected:
                io.log_info(
                    "Warning: multiple faces detected. Motion blur will not be used."
                )
                io.log_info("")
            else:
                _merger_compute_motion_vectors(frames)

            if not frames:
                io.log_info("No frames to merge in input_dir.")
                return

            InteractiveMergerSubprocessor(
                is_interactive=is_interactive,
                merger_session_filepath=model.get_strpath_storage_for_file(
                    'merger_session.dat'),
                predictor_func=predictor_func,
                predictor_input_shape=predictor_input_shape,
                face_enhancer_func=face_enhancer_func,
                xseg_256_extract_func=xseg_256_extract_func,
                merger_config=cfg,
                frames=frames,
                frames_root_path=input_path,
                output_path=output_path,
                output_mask_path=output_mask_path,
                model_iter=model.get_iter(),
                subprocess_count=subprocess_count,
            ).run()
        finally:
            model.finalize()

    except Exception:
        print(traceback.format_exc())
예제 #26
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        self.resolution = resolution = self.options['resolution']

        lowest_dense_res = self.lowest_dense_res = resolution // 32

        class Downscale(nn.ModelBase):
            def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.kernel_size = kernel_size
                super().__init__(*kwargs)

            def on_build(self, *args, **kwargs ):
                self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME')

            def forward(self, x):
                x = self.conv1(x)
                x = tf.nn.leaky_relu(x, 0.1)
                return x

            def get_out_ch(self):
                return self.out_ch

        class Upscale(nn.ModelBase):
            def on_build(self, in_ch, out_ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')

            def forward(self, x):
                x = self.conv1(x)
                x = tf.nn.leaky_relu(x, 0.1)
                x = nn.depth_to_space(x, 2)
                return x

        class ResidualBlock(nn.ModelBase):
            def on_build(self, ch, kernel_size=3 ):
                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')

            def forward(self, inp):
                x = self.conv1(inp)
                x = tf.nn.leaky_relu(x, 0.2)
                x = self.conv2(x)
                x = tf.nn.leaky_relu(inp+x, 0.2)
                return x

        class Encoder(nn.ModelBase):
            def on_build(self, in_ch, e_ch, ae_ch):
                self.down1 = Downscale(in_ch, e_ch, kernel_size=5)
                self.res1 = ResidualBlock(e_ch)
                self.down2 = Downscale(e_ch, e_ch*2, kernel_size=5)
                self.down3 = Downscale(e_ch*2, e_ch*4, kernel_size=5)
                self.down4 = Downscale(e_ch*4, e_ch*8, kernel_size=5)
                self.down5 = Downscale(e_ch*8, e_ch*8, kernel_size=5)
                self.res5 = ResidualBlock(e_ch*8)
                self.dense1 = nn.Dense( lowest_dense_res*lowest_dense_res*e_ch*8, ae_ch )

            def forward(self, inp):
                x = inp
                x = self.down1(x)
                x = self.res1(x)
                x = self.down2(x)
                x = self.down3(x)
                x = self.down4(x)
                x = self.down5(x)
                x = self.res5(x)
                x = nn.flatten(x)
                x = nn.pixel_norm(x, axes=-1)
                x = self.dense1(x)
                return x


        class Inter(nn.ModelBase):
            def __init__(self, ae_ch, ae_out_ch, **kwargs):
                self.ae_ch, self.ae_out_ch = ae_ch, ae_out_ch
                super().__init__(**kwargs)

            def on_build(self):
                ae_ch, ae_out_ch = self.ae_ch, self.ae_out_ch
                self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )

            def forward(self, inp):
                x = inp
                x = self.dense2(x)
                x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
                return x

            def get_out_ch(self):
                return self.ae_out_ch

        class Decoder(nn.ModelBase):
            def on_build(self, in_ch, d_ch, d_mask_ch ):
                self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
                self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
                self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
                self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)

                self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
                self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
                self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
                self.res3 = ResidualBlock(d_ch*2, kernel_size=3)

                self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
                self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
                self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
                self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME')

                self.out_conv  = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
                self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
                self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
                self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')

            def forward(self, inp):
                z = inp

                x = self.upscale0(z)
                x = self.res0(x)
                x = self.upscale1(x)
                x = self.res1(x)
                x = self.upscale2(x)
                x = self.res2(x)
                x = self.upscale3(x)
                x = self.res3(x)

                x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
                                                                 self.out_conv1(x),
                                                                 self.out_conv2(x),
                                                                 self.out_conv3(x)), nn.conv2d_ch_axis), 2) )

                m = self.upscalem0(z)
                m = self.upscalem1(m)
                m = self.upscalem2(m)
                m = self.upscalem3(m)
                m = self.upscalem4(m)
                m = tf.nn.sigmoid(self.out_convm(m))
                return x, m

        self.face_type = {'wf' : FaceType.WHOLE_FACE,
                          'head' : FaceType.HEAD}[ self.options['face_type'] ]

        if 'eyes_prio' in self.options:
            self.options.pop('eyes_prio')

        eyes_mouth_prio = self.options['eyes_mouth_prio']

        ae_dims = self.ae_dims = self.options['ae_dims']
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        morph_factor = self.options['morph_factor']
        
        pretrain = self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)
            
        self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
        random_warp = False if self.pretrain else self.options['random_warp']
        random_src_flip = self.random_src_flip if not self.pretrain else True
        random_dst_flip = self.random_dst_flip if not self.pretrain else True
        
        if self.pretrain:
            self.options_show_override['gan_power'] = 0.0
            self.options_show_override['random_warp'] = False
            self.options_show_override['lr_dropout'] = 'n'
            self.options_show_override['uniform_yaw'] = True
            
        masked_training = self.options['masked_training']
        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None

        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch=3
        bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')

            self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')

            self.target_srcm    = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
            self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
            self.target_dstm    = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
            self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')

            self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')

        # Initializing model classes

        with tf.device (models_opt_device):
            self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, ae_ch=ae_dims,  name='encoder')
            self.inter_src  = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_src')
            self.inter_dst  = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_dst')
            self.decoder = Decoder(in_ch=ae_dims, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

            self.model_filename_list += [   [self.encoder,  'encoder.npy'],
                                            [self.inter_src, 'inter_src.npy'],
                                            [self.inter_dst , 'inter_dst.npy'],
                                            [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                if gan_power != 0:
                    self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
                    self.model_filename_list += [ [self.GAN, 'GAN.npy'] ]

                # Initialize optimizers
                lr=5e-5
                lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0
                
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0

                self.all_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
                if pretrain:
                    self.trainable_weights = self.encoder.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
                else:
                    self.trainable_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()

                self.src_dst_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
                self.src_dst_opt.initialize_variables (self.all_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

                if gan_power != 0:
                    self.GAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
                    self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
                    self.model_filename_list += [ (self.GAN_opt, 'GAN_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gvs = []
            gpu_GAN_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src      = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst      = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src      = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst      = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm     = self.target_srcm[batch_slice,:,:,:]
                        gpu_target_srcm_em  = self.target_srcm_em[batch_slice,:,:,:]
                        gpu_target_dstm     = self.target_dstm[batch_slice,:,:,:]
                        gpu_target_dstm_em  = self.target_dstm_em[batch_slice,:,:,:]

                    # process model tensors
                    gpu_src_code = self.encoder (gpu_warped_src)
                    gpu_dst_code = self.encoder (gpu_warped_dst)
                    
                    if pretrain:
                        gpu_src_inter_src_code = self.inter_src (gpu_src_code)
                        gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
                        gpu_src_code = gpu_src_inter_src_code * nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
                        gpu_dst_code = gpu_src_dst_code = gpu_dst_inter_dst_code * nn.random_binomial( [bs_per_gpu, gpu_dst_inter_dst_code.shape.as_list()[1], 1,1] , p=0.25)
                    else:
                        gpu_src_inter_src_code = self.inter_src (gpu_src_code)
                        gpu_src_inter_dst_code = self.inter_dst (gpu_src_code)
                        gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
                        gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)

                        inter_rnd_binomial = nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
                        gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
                        gpu_dst_code = gpu_dst_inter_dst_code

                        ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32)
                        gpu_src_dst_code =  tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0],   [-1, ae_dims_slice , lowest_dense_res, lowest_dense_res]),
                                                        tf.slice(gpu_dst_inter_dst_code, [0,ae_dims_slice,0,0], [-1,ae_dims-ae_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 )

                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm,  max(1, resolution // 32) )
                    gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2

                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm,  max(1, resolution // 32) )
                    gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0-gpu_target_dstm_blur)
                    gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
                    gpu_target_src_masked_opt  = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt  = gpu_target_dst*gpu_target_dstm_blur if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
                    gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)
                    
                    if resolution < 256:
                        gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    else:
                        gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                        gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square(  gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
                    if eyes_mouth_prio:
                        gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
                    gpu_dst_loss += 0.1*tf.reduce_mean(tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
                    gpu_dst_losses += [gpu_dst_loss]

                    if not pretrain:
                        if resolution < 256:
                            gpu_src_loss =  tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        else:
                            gpu_src_loss =  tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                            gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                        gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

                        if eyes_mouth_prio:
                            gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3])

                        gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
                    else:
                        gpu_src_loss = gpu_dst_loss
                    
                    gpu_src_losses += [gpu_src_loss]
                    
                    if pretrain:
                        gpu_G_loss = gpu_dst_loss
                    else:     
                        gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLossOnes(logits):
                        """Sigmoid cross-entropy of ``logits`` against an all-ones target.

                        The elementwise loss is averaged over axes 1, 2 and 3, leaving one
                        value per entry of the leading axis (logits must be at least rank 4).
                        """
                        ones_target = tf.ones_like(logits)
                        elementwise = tf.nn.sigmoid_cross_entropy_with_logits(labels=ones_target, logits=logits)
                        return tf.reduce_mean(elementwise, axis=[1, 2, 3])

                    def DLossZeros(logits):
                        """Sigmoid cross-entropy of ``logits`` against an all-zeros target.

                        The elementwise loss is averaged over axes 1, 2 and 3, leaving one
                        value per entry of the leading axis (logits must be at least rank 4).
                        """
                        zeros_target = tf.zeros_like(logits)
                        elementwise = tf.nn.sigmoid_cross_entropy_with_logits(labels=zeros_target, logits=logits)
                        return tf.reduce_mean(elementwise, axis=[1, 2, 3])

                    if gan_power != 0:
                        gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked_opt)
                        gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked_opt)
                        gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked_opt)
                        gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked_opt)

                        gpu_D_src_dst_loss = (DLossOnes (gpu_target_src_d)   + DLossOnes (gpu_target_src_d2) + \
                                              DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
                                              DLossOnes (gpu_target_dst_d)   + DLossOnes (gpu_target_dst_d2) + \
                                              DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
                                             ) * ( 1.0 / 8)

                        gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.GAN.get_weights() ) ]

                        gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
                                       DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
                                      ) * gan_power

                        if masked_training:
                            # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
                            gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
                            gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )

                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.trainable_weights ) ]


            # Average losses and gradients, and create optimizer update ops
            with tf.device(f'/CPU:0'):
                pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

            with tf.device (models_opt_device):
                src_loss = tf.concat(gpu_src_losses, 0)
                dst_loss = tf.concat(gpu_dst_losses, 0)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

                if gan_power != 0:
                    src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )
                    #GAN_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gvs) )


            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em,
                              warped_dst, target_dst, target_dstm, target_dstm_em):
                """Run the src/dst optimizer update once on the given batch.

                Every argument is fed into the matching ``self.*`` placeholder. The
                session evaluates ``src_loss`` and ``dst_loss`` together with
                ``src_dst_loss_gv_op`` (the ``src_dst_opt`` update op).

                Returns:
                    (src_loss_value, dst_loss_value) as produced by the session run.
                """
                feeds = {
                    self.warped_src: warped_src,
                    self.target_src: target_src,
                    self.target_srcm: target_srcm,
                    self.target_srcm_em: target_srcm_em,
                    self.warped_dst: warped_dst,
                    self.target_dst: target_dst,
                    self.target_dstm: target_dstm,
                    self.target_dstm_em: target_dstm_em,
                }
                src_loss_value, dst_loss_value, _ = nn.tf_sess.run(
                    [src_loss, dst_loss, src_dst_loss_gv_op], feed_dict=feeds)
                return src_loss_value, dst_loss_value
            self.src_dst_train = src_dst_train

            if gan_power != 0:
                def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em,
                                    warped_dst, target_dst, target_dstm, target_dstm_em):
                    """Run the GAN discriminator optimizer update once on the given batch.

                    Every argument is fed into the matching ``self.*`` placeholder, and
                    only ``src_D_src_dst_loss_gv_op`` (the ``GAN_opt`` update op) is
                    executed. Nothing is returned.
                    """
                    feeds = {
                        self.warped_src: warped_src,
                        self.target_src: target_src,
                        self.target_srcm: target_srcm,
                        self.target_srcm_em: target_srcm_em,
                        self.warped_dst: warped_dst,
                        self.target_dst: target_dst,
                        self.target_dstm: target_dstm,
                        self.target_dstm_em: target_dstm_em,
                    }
                    nn.tf_sess.run([src_D_src_dst_loss_gv_op], feed_dict=feeds)
                self.D_src_dst_train = D_src_dst_train


            def AE_view(warped_src, warped_dst, morph_value):
                """Evaluate the training-graph predictions for the given inputs.

                ``morph_value`` is fed as a one-element list into ``self.morph_value_t``.

                Returns:
                    The session results for [pred_src_src, pred_dst_dst, pred_dst_dstm,
                    pred_src_dst, pred_src_dstm], in that order.
                """
                fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]
                feeds = {
                    self.warped_src: warped_src,
                    self.warped_dst: warped_dst,
                    self.morph_value_t: [morph_value],
                }
                return nn.tf_sess.run(fetches, feed_dict=feeds)

            self.AE_view = AE_view
        else:
            #Initializing merge function
            with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
                gpu_dst_code = self.encoder (self.warped_dst)
                gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
                gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)

                ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32)
                gpu_src_dst_code =  tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0],   [-1, ae_dims_slice , lowest_dense_res, lowest_dense_res]),
                                                 tf.slice(gpu_dst_inter_dst_code, [0,ae_dims_slice,0,0], [-1,ae_dims-ae_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 )

                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)

            def AE_merge(warped_dst, morph_value):
                """Evaluate the merge (inference) graph for a dst batch.

                ``morph_value`` is fed as a one-element list into ``self.morph_value_t``,
                which controls how many ae_dims channels come from the inter_src code.

                Returns:
                    The session results for [gpu_pred_src_dst, gpu_pred_dst_dstm,
                    gpu_pred_src_dstm], in that order.
                """
                fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                feeds = {self.warped_dst: warped_dst, self.morph_value_t: [morph_value]}
                return nn.tf_sess.run(fetches, feed_dict=feeds)

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if model == self.inter_src or model == self.inter_dst:
                    do_init = True
            else:
                do_init = self.is_first_run()
                if self.is_training and gan_power != 0 and model == self.GAN:
                    if self.gan_model_changed:
                        do_init = True
                        
            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
            if do_init:
                model.init_weights()


        ###############

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None


            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if ct_mode is not None:
                src_generators_count = int(src_generators_count * 1.5)

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=random_src_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                           'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                           'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                                'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                                'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                        generators_count=dst_generators_count )
                             ])

            self.last_src_samples_loss = []
            self.last_dst_samples_loss = []
            if self.pretrain_just_disabled:
                self.update_sample_for_preview(force_new=True)
예제 #27
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        self.resolution = resolution = self.options['resolution']
        self.face_type = {'h'  : FaceType.HALF,
                          'mf' : FaceType.MID_FULL,
                          'f'  : FaceType.FULL,
                          'wf' : FaceType.WHOLE_FACE,
                          'head' : FaceType.HEAD}[ self.options['face_type'] ]

        eyes_prio = self.options['eyes_prio']

        archi_split = self.options['archi'].split('-')

        if len(archi_split) == 2:
            archi_type, archi_opts = archi_split
        elif len(archi_split) == 1:
            archi_type, archi_opts = archi_split[0], None

        ae_dims = self.options['ae_dims']
        e_dims = self.options['e_dims']
        d_dims = self.options['d_dims']
        d_mask_dims = self.options['d_mask_dims']
        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
        random_warp = False if self.pretrain else self.options['random_warp']

        if self.pretrain:
            self.options_show_override['gan_power'] = 0.0
            self.options_show_override['random_warp'] = False
            self.options_show_override['lr_dropout'] = 'n'
            self.options_show_override['face_style_power'] = 0.0
            self.options_show_override['bg_style_power'] = 0.0
            self.options_show_override['uniform_yaw'] = True

        masked_training = self.options['masked_training']
        import dfl
        dfl.load_config()
        masked_training = dfl.get_config("masked_training", "1") == "1"
        ct_mode = self.options['ct_mode']
        if ct_mode == 'none':
            ct_mode = None

        models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch=3
        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)

            self.target_src = tf.placeholder (nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder (nn.floatx, bgr_shape)

            self.target_srcm_all = tf.placeholder (nn.floatx, mask_shape)
            self.target_dstm_all = tf.placeholder (nn.floatx, mask_shape)

        # Initializing model classes
        model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)

        with tf.device (models_opt_device):
            if 'df' in archi_type:
                self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

                self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
                inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))

                self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
                self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

                self.model_filename_list += [ [self.encoder,     'encoder.npy'    ],
                                              [self.inter,       'inter.npy'      ],
                                              [self.decoder_src, 'decoder_src.npy'],
                                              [self.decoder_dst, 'decoder_dst.npy']  ]

                if self.is_training:
                    if self.options['true_face_power'] != 0:
                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
                        self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]

            elif 'liae' in archi_type:
                self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))

                self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
                self.inter_B  = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')

                inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
                inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
                inters_out_ch = inter_AB_out_ch+inter_B_out_ch
                self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')

                self.model_filename_list += [ [self.encoder,  'encoder.npy'],
                                              [self.inter_AB, 'inter_AB.npy'],
                                              [self.inter_B , 'inter_B.npy'],
                                              [self.decoder , 'decoder.npy'] ]

            if self.is_training:
                if gan_power != 0:
                    self.D_src = nn.UNetPatchDiscriminator(patch_size=resolution//16, in_ch=input_ch, name="D_src")
                    self.model_filename_list += [ [self.D_src, 'D_src_v2.npy'] ]

                # Initialize optimizers
                lr=5e-5
                lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0
                clipnorm = 1.0 if self.options['clipgrad'] else 0.0

                if 'df' in archi_type:
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
                elif 'liae' in archi_type:
                    self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()

                self.src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
                self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
                self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]

                if self.options['true_face_power'] != 0:
                    self.D_code_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt')
                    self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
                    self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]

                if gan_power != 0:
                    self.D_src_dst_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
                    self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
                    self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_v2_opt.npy') ]

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)


            # Compute losses per GPU
            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_G_loss_gvs = []
            gpu_D_code_loss_gvs = []
            gpu_D_src_dst_loss_gvs = []
            for gpu_id in range(gpu_count):
                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src      = self.warped_src [batch_slice,:,:,:]
                        gpu_warped_dst      = self.warped_dst [batch_slice,:,:,:]
                        gpu_target_src      = self.target_src [batch_slice,:,:,:]
                        gpu_target_dst      = self.target_dst [batch_slice,:,:,:]
                        gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
                        gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]

                    # process model tensors
                    if 'df' in archi_type:
                        gpu_src_code     = self.inter(self.encoder(gpu_warped_src))
                        gpu_dst_code     = self.inter(self.encoder(gpu_warped_dst))
                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    elif 'liae' in archi_type:
                        gpu_src_code = self.encoder (gpu_warped_src)
                        gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
                        gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis  )
                        gpu_dst_code = self.encoder (gpu_warped_dst)
                        gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                        gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                        gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
                        gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )

                        gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
                        gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                        gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    # unpack masks from one combined mask
                    gpu_target_srcm      = tf.clip_by_value (gpu_target_srcm_all, 0, 1)
                    gpu_target_dstm      = tf.clip_by_value (gpu_target_dstm_all, 0, 1)
                    gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
                    gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)

                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm,  max(1, resolution // 32) )
                    gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2

                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm,  max(1, resolution // 32) )
                    gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
                    gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2

                    gpu_target_dst_masked      = gpu_target_dst*gpu_target_dstm_blur
                    gpu_target_dst_style_masked      = gpu_target_dst*gpu_target_dstm_style_blur
                    gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)

                    gpu_target_src_masked_opt  = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt  = gpu_target_dst_masked if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
                    gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)

                    if resolution < 256:
                        gpu_src_loss =  tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    else:
                        gpu_src_loss =  tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])

                    if eyes_prio:
                        gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])

                    gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )

                    face_style_power = self.options['face_style_power'] / 100.0
                    if face_style_power != 0 and not self.pretrain:
                        gpu_src_loss += nn.style_loss(gpu_psd_target_dst_style_masked, gpu_target_dst_style_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)

                    bg_style_power = self.options['bg_style_power'] / 100.0
                    if bg_style_power != 0 and not self.pretrain:
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked,  gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                        gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] )

                    if resolution < 256:
                        gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                    else:
                        gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
                        gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square(  gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])


                    if eyes_prio:
                        gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])

                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    gpu_G_loss = gpu_src_loss + gpu_dst_loss

                    def DLoss(labels,logits):
                        """Sigmoid cross-entropy of `logits` against `labels`, averaged over axes 1-3 (axis 0 is kept)."""
                        xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
                        return tf.reduce_mean(xent, axis=[1,2,3])

                    if self.options['true_face_power'] != 0:
                        gpu_src_code_d = self.code_discriminator( gpu_src_code )
                        gpu_src_code_d_ones  = tf.ones_like (gpu_src_code_d)
                        gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
                        gpu_dst_code_d = self.code_discriminator( gpu_dst_code )
                        gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)

                        gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)

                        gpu_D_code_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
                                           DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5

                        gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ]

                    if gan_power != 0:
                        gpu_pred_src_src_d, \
                        gpu_pred_src_src_d2           = self.D_src(gpu_pred_src_src_masked_opt)

                        gpu_pred_src_src_d_ones  = tf.ones_like (gpu_pred_src_src_d)
                        gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d)

                        gpu_pred_src_src_d2_ones  = tf.ones_like (gpu_pred_src_src_d2)
                        gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2)

                        gpu_target_src_d, \
                        gpu_target_src_d2            = self.D_src(gpu_target_src_masked_opt)

                        gpu_target_src_d_ones    = tf.ones_like(gpu_target_src_d)
                        gpu_target_src_d2_ones    = tf.ones_like(gpu_target_src_d2)

                        gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones      , gpu_target_src_d) + \
                                              DLoss(gpu_pred_src_src_d_zeros   , gpu_pred_src_src_d) ) * 0.5 + \
                                             (DLoss(gpu_target_src_d2_ones      , gpu_target_src_d2) + \
                                              DLoss(gpu_pred_src_src_d2_zeros   , gpu_pred_src_src_d2) ) * 0.5

                        gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights()

                        gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d)  + \
                                                 DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))

                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]


            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                pred_src_src  = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst  = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst  = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = tf.concat(gpu_src_losses, 0)
                dst_loss = tf.concat(gpu_dst_losses, 0)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

                if self.options['true_face_power'] != 0:
                    D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs))

                if gan_power != 0:
                    src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )


            # Initializing training and view functions
            def src_dst_train(warped_src, target_src, target_srcm_all,
                              warped_dst, target_dst, target_dstm_all):
                """Run one src/dst autoencoder optimizer step.

                Feeds the six batch arrays into their placeholders, applies
                `src_dst_loss_gv_op`, and returns the evaluated (src_loss, dst_loss).
                """
                feed = {
                    self.warped_src      : warped_src,
                    self.target_src      : target_src,
                    self.target_srcm_all : target_srcm_all,
                    self.warped_dst      : warped_dst,
                    self.target_dst      : target_dst,
                    self.target_dstm_all : target_dstm_all,
                }
                src_value, dst_value, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op], feed_dict=feed)
                return src_value, dst_value
            self.src_dst_train = src_dst_train

            if self.options['true_face_power'] != 0:
                def D_train(warped_src, warped_dst):
                    """Apply one code-discriminator update (`D_loss_gv_op`); returns nothing."""
                    feed = {self.warped_src: warped_src, self.warped_dst: warped_dst}
                    nn.tf_sess.run([D_loss_gv_op], feed_dict=feed)
                self.D_train = D_train

            if gan_power != 0:
                def D_src_dst_train(warped_src, target_src, target_srcm_all,
                                    warped_dst, target_dst, target_dstm_all):
                    """Apply one update of the D_src discriminator (`src_D_src_dst_loss_gv_op`); returns nothing."""
                    feed = {
                        self.warped_src      : warped_src,
                        self.target_src      : target_src,
                        self.target_srcm_all : target_srcm_all,
                        self.warped_dst      : warped_dst,
                        self.target_dst      : target_dst,
                        self.target_dstm_all : target_dstm_all,
                    }
                    nn.tf_sess.run([src_D_src_dst_loss_gv_op], feed_dict=feed)
                self.D_src_dst_train = D_src_dst_train


            def AE_view(warped_src, warped_dst):
                """Evaluate [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm] for a src/dst batch."""
                fetches = [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]
                feed = {self.warped_src: warped_src, self.warped_dst: warped_dst}
                return nn.tf_sess.run(fetches, feed_dict=feed)
            self.AE_view = AE_view
        else:
            # Initializing merge function
            with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                if 'df' in archi_type:
                    gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

                elif 'liae' in archi_type:
                    gpu_dst_code = self.encoder (self.warped_dst)
                    gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
                    gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
                    gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
                    gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)

                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)


            def AE_merge( warped_dst):
                """Evaluate [src-from-dst image, dst mask, src-from-dst mask] for a batch of dst faces."""
                fetches = [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm]
                return nn.tf_sess.run(fetches, feed_dict={self.warped_dst: warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                do_init = False
                if 'df' in archi_type:
                    if model == self.inter:
                        do_init = True
                elif 'liae' in archi_type:
                    if model == self.inter_AB or model == self.inter_B:
                        do_init = True
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None

            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2
            if ct_mode is not None:
                src_generators_count = int(src_generators_count * 1.5)

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                           'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode,                                           'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                        generators_count=src_generators_count ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                                'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR,                                                                'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False                      , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G,   'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain,
                        generators_count=dst_generators_count )
                             ])

            self.last_src_samples_loss = []
            self.last_dst_samples_loss = []

            if self.pretrain_just_disabled:
                self.update_sample_for_preview(force_new=True)
예제 #28
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        self.model_data_format = "NHWC"#"NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        self.resolution = resolution = self.options['resolution']
        self.face_type = {'h'  : FaceType.HALF,
                          'mf' : FaceType.MID_FULL,
                          'f'  : FaceType.FULL,
                          'wf' : FaceType.WHOLE_FACE,
                          'head' : FaceType.HEAD}[ self.options['face_type'] ]


        models_opt_on_gpu = True#False if len(devices) == 0 else self.options['models_opt_on_gpu']
        models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device=='/CPU:0'

        input_ch=3
        bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
        mask_shape = nn.get4Dshape(resolution,resolution,1)
        self.model_filename_list = []

        class BaseModel(nn.ModelBase):
            """Eight-stage convolutional encoder: conv -> (optional FRN) -> TLU per stage.

            Stages 2, 4 and 6 use stride 2, so the spatial size shrinks 8x (SAME
            padding) and the final stage has ``base_ch*8`` channels; when ``out_ch``
            is given, a trailing 1x1 conv projects to ``out_ch`` channels.
            """
            def on_build(self, in_ch, base_ch, out_ch=None):
                # Output channel count of each of the eight stages.
                widths = [ base_ch*mult for mult in (1, 1, 2, 2, 4, 4, 8, 8) ]

                # Stem: 7x7 conv with the Conv2D default bias setting.
                convs = [ nn.Conv2D( in_ch, base_ch, kernel_size=7, strides=1, padding='SAME') ]
                for stage in range(1, len(widths)):
                    # Even-numbered stages downsample; odd ones keep the resolution.
                    convs.append( nn.Conv2D( widths[stage-1], widths[stage], kernel_size=3,
                                             strides=1 if stage % 2 else 2,
                                             use_bias=False, padding='SAME') )
                self.convs = convs

                # The stem is not normalized; every later stage is.
                self.frns = [ None ] + [ nn.FRNorm2D(w) for w in widths[1:] ]
                self.tlus = [ nn.TLU(w) for w in widths ]

                if out_ch is None:
                    self.out_conv = None
                else:
                    self.out_conv = nn.Conv2D( widths[-1], out_ch, kernel_size=1, strides=1, use_bias=False, padding='VALID')

            def forward(self, inp):
                x = inp
                for conv, frn, tlu in zip(self.convs, self.frns, self.tlus):
                    x = conv(x)
                    if frn is not None:
                        x = frn(x)
                    x = tlu(x)
                return x if self.out_conv is None else self.out_conv(x)

        class Regressor(nn.ModelBase):
            """Decoder from (encoder features ++ landmark maps) back to an image.

            Three "conv -> conv to 4x channels -> depth_to_space(2)" stages each double
            the spatial size while narrowing base_ch*8 -> *4 -> *2. A final conv
            reduces to base_ch, and a 3x3 conv + sigmoid yields ``out_ch`` channels in (0, 1).
            """
            def on_build(self, lmrks_ch, base_ch, out_ch):
                # (in_channels, out_channels, depth_to_space_after) for each stage.
                stages = []
                in_width = base_ch*8 + lmrks_ch   # features concatenated with landmark maps
                for mult in (8, 4, 2):
                    width = base_ch*mult
                    stages += [ (in_width, width, False), (width, width*4, True) ]
                    # depth_to_space(2) turns width*4 channels back into width channels.
                    in_width = width
                stages += [ (in_width, base_ch, False) ]

                self.convs = [ nn.Conv2D( c_in, c_out, kernel_size=3, strides=1, use_bias=False, padding='SAME')
                               for c_in, c_out, _ in stages ]
                self.frns = [ nn.FRNorm2D(c_out) for _, c_out, _ in stages ]
                self.tlus = [ nn.TLU(c_out) for _, c_out, _ in stages ]
                self.use_upscale = [ upscale for _, _, upscale in stages ]

                self.out_conv = nn.Conv2D( base_ch, out_ch, kernel_size=3, strides=1, padding='SAME')

            def forward(self, inp):
                x = inp
                for conv, frn, tlu, upscale in zip(self.convs, self.frns, self.tlus, self.use_upscale):
                    x = tlu(frn(conv(x)))
                    if upscale:
                        x = nn.depth_to_space(x, 2)
                return tf.nn.sigmoid(self.out_conv(x))

        def get_coord(x, other_axis, axis_size):
            """Soft-argmax coordinate of every heatmap channel along one spatial axis.

            x: heatmap tensor; presumably NHWC [B, H, W, NMAP] (callers pass
               shape[1]/shape[2] as axis_size) -- confirm against callers.
            other_axis: spatial axis that is averaged away.
            axis_size: length of the remaining spatial axis.

            Returns (coord, prob): `prob` is the softmax over the remaining axis and
            `coord` its expectation over a [-1, 1] grid, so it lies in [-1, 1].
            """
            # Collapse the orthogonal axis into a 1-D profile per map, then normalize it.
            prob = tf.nn.softmax(tf.reduce_mean(x, axis=other_axis), axis=1)
            # Normalized coordinate grid, shaped to broadcast over batch and maps.
            grid = tf.reshape(tf.to_float(tf.linspace(-1.0, 1.0, axis_size)), [1, axis_size, 1])
            return tf.reduce_sum(prob * grid, axis=1), prob

        def get_gaussian_maps(mu_x, mu_y, width, height, inv_std=10.0, mode='rot'):
            """
            Render one 2-D gaussian-like blob per landmark.

            mu_x, mu_y: landmark centers along the horizontal (width) and vertical
                (height) axes, in the same normalized [-1, 1] frame as the sampling
                grid built below. Presumably shaped [B, NMAPS] -- confirm against callers.
            width, height: spatial size of the rendered maps.
            inv_std: inverse width scale; larger values give sharper blobs.
            mode:
                'rot'    -> exp(-d), with d = ((y-mu_y)^2 + (x-mu_x)^2) * inv_std^2
                'flat'   -> exp(-(d + 1e-5)^0.25)
                'ankush' -> separable exp(-sqrt(1e-4 + |(mu - t) * inv_std|)) profiles
                            along y and x, combined by an outer product (matmul).

            Returns a [B, height, width, NMAPS] tensor.
            Raises ValueError for an unknown mode.
            """
            # Rows (y, axis 2 before the final transpose) must span the height and
            # columns (x, axis 3) the width; otherwise non-square maps come out
            # as [B, W, H, NMAPS].
            y = tf.to_float(tf.linspace(-1.0, 1.0, height))
            x = tf.to_float(tf.linspace(-1.0, 1.0, width))

            if mode in ['rot', 'flat']:
                mu_y, mu_x = mu_y[...,None,None], mu_x[...,None,None]

                y = tf.reshape(y, [1, 1, height, 1])
                x = tf.reshape(x, [1, 1, 1, width])

                g_y = tf.square(y - mu_y)
                g_x = tf.square(x - mu_x)
                dist = (g_y + g_x) * inv_std**2

                if mode == 'rot':
                    g_yx = tf.exp(-dist)
                else:
                    g_yx = tf.exp(-tf.pow(dist + 1e-5, 0.25))

            elif mode == 'ankush':
                y = tf.reshape(y, [1, 1, height])
                x = tf.reshape(x, [1, 1, width])

                g_y = tf.exp(-tf.sqrt(1e-4 + tf.abs((mu_y[...,None] - y) * inv_std)))
                g_x = tf.exp(-tf.sqrt(1e-4 + tf.abs((mu_x[...,None] - x) * inv_std)))

                # Shapes below assume [B, NMAPS] centers.
                g_y = tf.expand_dims(g_y, axis=3)   # [B, NMAPS, H, 1]
                g_x = tf.expand_dims(g_x, axis=2)   # [B, NMAPS, 1, W]
                g_yx = tf.matmul(g_y, g_x)          # [B, NMAPS, H, W]

            else:
                raise ValueError('Unknown mode: ' + str(mode))

            # [B, NMAPS, H, W] -> [B, H, W, NMAPS]
            g_yx = tf.transpose(g_yx, perm=[0, 2, 3, 1])
            return g_yx

        with tf.device ('/CPU:0'):
            #Place holders on CPU
            self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
            self.target_src = tf.placeholder (nn.floatx, bgr_shape)



        # Initializing model classes
        #model_archi = nn.DeepFakeArchi(resolution, mod='uhd' if 'uhd' in archi else None)
        self.landmarks_count = 512
        self.n_ch = 32
        with tf.device (models_opt_device):
            self.detector = BaseModel(3, self.n_ch, out_ch=self.landmarks_count, name='Detector')
            self.extractor = BaseModel(3, self.n_ch, name='Extractor')
            self.regressor = Regressor(self.landmarks_count, self.n_ch, 3, name='Regressor')



            self.model_filename_list += [ [self.detector,  'detector.npy'],
                                          [self.extractor, 'extractor.npy'],
                                          [self.regressor, 'regressor.npy'] ]

            if self.is_training:
                 # Initialize optimizers
                lr=5e-5
                lr_dropout = 0.3#0.3 if self.options['lr_dropout'] and not self.pretrain else 1.0
                clipnorm = 0.0#1.0 if self.options['clipgrad'] else 0.0
                self.model_opt = nn.RMSprop(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='model_opt')
                self.model_filename_list += [ (self.model_opt, 'model_opt.npy') ]

                self.model_trainable_weights = self.detector.get_weights() + self.extractor.get_weights() + self.regressor.get_weights()
                self.model_opt.initialize_variables (self.model_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)



        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices) )
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size( gpu_count*bs_per_gpu)

            # Compute losses per GPU
            gpu_src_rec_list = []
            gauss_mu_list = []
            
            gpu_src_losses = []
            gpu_G_loss_gvs = []
            for gpu_id in range(gpu_count):
                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                        gpu_warped_src      = self.warped_src [batch_slice,:,:,:]
                        gpu_target_src      = self.target_src [batch_slice,:,:,:]

                    # process model tensors

                    gpu_src_feat     = self.extractor(gpu_warped_src)
                    gpu_src_heatmaps = self.detector(gpu_target_src)

                    gauss_y, gauss_y_prob = get_coord(gpu_src_heatmaps, 2, gpu_src_heatmaps.shape.as_list()[1] )
                    gauss_x, gauss_x_prob = get_coord(gpu_src_heatmaps, 1, gpu_src_heatmaps.shape.as_list()[2] )
                    gauss_mu = tf.stack ( (gauss_x, gauss_y), -1)
                    
                    dist_loss = []
                    for i in range(self.landmarks_count):
                        
                        t = tf.concat( (gauss_mu[:,0:i], gauss_mu[:,i+1:] ), axis=1 )
                        
                        
                        diff = t - gauss_mu[:,i:i+1]
                        dist = tf.sqrt( diff[...,0]**2+diff[...,1]**2 )
                        
                        dist_loss += [ tf.reduce_mean(2.0 - dist,-1)  ]
                        
                    dist_loss = sum(dist_loss) / self.landmarks_count
                    #import code
                    #code.interact(local=dict(globals(), **locals()))

                    
                    
                    gauss_xy = get_gaussian_maps ( gauss_x, gauss_y, 16, 16 )

                    gpu_src_rec = self.regressor( tf.concat ( (gpu_src_feat, gauss_xy), -1) )

                    gpu_src_rec_list.append(gpu_src_rec)
                    gauss_mu_list.append(gauss_mu)
                    
                    gpu_src_loss =  tf.reduce_mean ( 10*nn.dssim(gpu_target_src, gpu_src_rec, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
                    gpu_src_loss += tf.reduce_mean ( 10*tf.square (gpu_target_src - gpu_src_rec), axis=[1,2,3])
                    gpu_src_loss += dist_loss
                    
                    
                    gpu_src_losses += [gpu_src_loss]

                    gpu_G_loss_gvs += [ nn.gradients ( gpu_src_loss, self.model_trainable_weights ) ]
                    
            # Average losses and gradients, and create optimizer update ops
            with tf.device (models_opt_device):
                src_rec  = nn.concat(gpu_src_rec_list, 0)
                gauss_mu = nn.concat(gauss_mu_list, 0)
                src_loss = tf.concat(gpu_src_losses, 0)
                loss_gv_op = self.model_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))

            # Initializing training and view functions
            def ae_train(warped_src, target_src):
                """Apply one optimizer step to the detector/extractor/regressor weights and return the src loss values."""
                feed = {self.warped_src: warped_src, self.target_src: target_src}
                loss_value, _ = nn.tf_sess.run([src_loss, loss_gv_op], feed_dict=feed)
                return loss_value
            self.ae_train = ae_train

            def AE_view(warped_src, target_src):
                """Evaluate [reconstruction, landmark centers] for a batch; centers are stacked as (x, y) on the last axis."""
                feed = {self.warped_src: warped_src, self.target_src: target_src}
                return nn.tf_sess.run([src_rec, gauss_mu], feed_dict=feed)
            self.AE_view = AE_view
            
        else:
            # Initializing merge function
            with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
                if 'df' in archi:
                    gpu_dst_code     = self.inter(self.encoder(self.warped_dst))
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                    _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge( warped_dst):
                # Evaluate the merge outputs [src-from-dst image, dst mask, src-from-dst mask]
                # for a batch of dst faces.
                # NOTE(review): the fetched tensors are only assigned inside the
                # `if 'df' in archi:` branch above. `archi` is not assigned anywhere in the
                # visible part of this method, and that branch uses self.inter /
                # self.encoder / self.decoder_src / self.decoder_dst / self.warped_dst,
                # which this model does not build here (only detector/extractor/regressor
                # and the warped_src/target_src placeholders are created). This looks like
                # leftover code from another model -- confirm before relying on the
                # non-training path.
                return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})

            self.AE_merge = AE_merge

        # Loading/initializing all models/optimizers weights
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            do_init = self.is_first_run()
            if not do_init:
                do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
            if do_init:
                model.init_weights()

        # initializing sample generators
        if self.is_training:
            training_data_src_path = self.training_data_src_path
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_generators_count = cpu_count // 2

            self.set_training_data_generators ([
                    SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size()*2,
                        sample_process_options=SampleProcessor.Options(random_flip=False),
                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                              ],
                        generators_count=src_generators_count ),
                             ])