예제 #1
0
def apply_xseg(input_path, model_path):
    if not input_path.exists():
        raise ValueError(f'{input_path} not found. Please ensure it exists.')

    if not model_path.exists():
        raise ValueError(f'{model_path} not found. Please ensure it exists.')

    io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')

    device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
    nn.initialize(device_config)

    xseg = XSegNet(name='XSeg',
                   load_weights=True,
                   weights_file_root=model_path,
                   data_format=nn.data_format,
                   raise_on_no_model_files=True)
    res = xseg.get_resolution()

    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)

    for filepath in io.progress_bar_generator(images_paths, "Processing"):
        dflimg = DFLIMG.load(filepath)
        if dflimg is None or not dflimg.has_data():
            io.log_info(f'{filepath} is not a DFLIMG')
            continue

        img = cv2_imread(filepath).astype(np.float32) / 255.0
        h, w, c = img.shape
        if w != res:
            img = cv2.resize(img, (res, res), interpolation=cv2.INTER_CUBIC)
            if len(img.shape) == 2:
                img = img[..., None]

        mask = xseg.extract(img)
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        dflimg.set_xseg_mask(mask)
        dflimg.save()
예제 #2
0
def apply_xseg(input_path, model_path):
    if not input_path.exists():
        raise ValueError(f'{input_path} not found. Please ensure it exists.')

    if not model_path.exists():
        raise ValueError(f'{model_path} not found. Please ensure it exists.')
        
    face_type = None
    
    model_dat = model_path / 'XSeg_data.dat'
    if model_dat.exists():
        dat = pickle.loads( model_dat.read_bytes() )
        dat_options = dat.get('options', None)
        if dat_options is not None:
            face_type = dat_options.get('face_type', None)
        
        
        
    if face_type is None:
        face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower()
        if face_type == 'same':
            face_type = None
    
    if face_type is not None:
        face_type = {'h'  : FaceType.HALF,
                     'mf' : FaceType.MID_FULL,
                     'f'  : FaceType.FULL,
                     'wf' : FaceType.WHOLE_FACE,
                     'head' : FaceType.HEAD}[face_type]
                     
    io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')

    device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
    nn.initialize(device_config)
        
    
    
    xseg = XSegNet(name='XSeg', 
                    load_weights=True,
                    weights_file_root=model_path,
                    data_format=nn.data_format,
                    raise_on_no_model_files=True)
    xseg_res = xseg.get_resolution()
              
    images_paths = pathex.get_image_paths(input_path, return_Path_class=True)
    
    for filepath in io.progress_bar_generator(images_paths, "Processing"):
        dflimg = DFLIMG.load(filepath)
        if dflimg is None or not dflimg.has_data():
            io.log_info(f'{filepath} is not a DFLIMG')
            continue
        
        img = cv2_imread(filepath).astype(np.float32) / 255.0
        h,w,c = img.shape
        
        img_face_type = FaceType.fromString( dflimg.get_face_type() )
        if face_type is not None and img_face_type != face_type:
            lmrks = dflimg.get_source_landmarks()
            
            fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type)
            imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type)
            
            g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w) ]), fmat, True)
            g_p2 = LandmarksProcessor.transform_points (g_p, imat)
            
            mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w) ]) )
            
            img = cv2.warpAffine(img, mat, (w, w), cv2.INTER_LANCZOS4)
            img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
        else:
            if w != xseg_res:
                img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 )    
                    
        if len(img.shape) == 2:
            img = img[...,None]            
    
        mask = xseg.extract(img)
        
        if face_type is not None and img_face_type != face_type:
            mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4)
            mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4)
            mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4)
        mask[mask < 0.5]=0
        mask[mask >= 0.5]=1    
        dflimg.set_xseg_mask(mask)
        dflimg.save()
예제 #3
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if self.is_exporting or (len(
            device_config.devices) != 0 and not self.is_debug()) else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256

        self.face_type = {
            'h': FaceType.HALF,
            'mf': FaceType.MID_FULL,
            'f': FaceType.FULL,
            'wf': FaceType.WHOLE_FACE,
            'head': FaceType.HEAD
        }[self.options['face_type']]

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name

        bgr_shape = nn.get4Dshape(resolution, resolution, 3)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        # Initializing model classes
        self.model = XSegNet(name='XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001,
                                                  lr_dropout=0.3,
                                                  name='opt'),
                             data_format=nn.data_format)

        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []

            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}'
                               if len(devices) != 0 else f'/CPU:0'):
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                        gpu_input_t, pretrain=self.pretrain)
                    gpu_pred_list.append(gpu_pred_t)

                    if self.pretrain:
                        # Structural loss
                        gpu_loss = tf.reduce_mean(
                            5 * nn.dssim(gpu_target_t,
                                         gpu_pred_t,
                                         max_val=1.0,
                                         filter_size=int(resolution / 11.6)),
                            axis=[1])
                        gpu_loss += tf.reduce_mean(
                            5 * nn.dssim(gpu_target_t,
                                         gpu_pred_t,
                                         max_val=1.0,
                                         filter_size=int(resolution / 23.2)),
                            axis=[1])
                        # Pixel loss
                        gpu_loss += tf.reduce_mean(
                            10 * tf.square(gpu_target_t - gpu_pred_t),
                            axis=[1, 2, 3])
                    else:
                        gpu_loss = tf.reduce_mean(
                            tf.nn.sigmoid_cross_entropy_with_logits(
                                labels=gpu_target_t, logits=gpu_pred_logits_t),
                            axis=[1, 2, 3])

                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [
                        nn.gradients(gpu_loss, self.model.get_weights())
                    ]

            # Average losses and gradients, and create optimizer update ops
            #with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok
            with tf.device(models_opt_device):
                pred = tf.concat(gpu_pred_list, 0)
                loss = tf.concat(gpu_losses, 0)
                loss_gv_op = self.model.opt.get_update_op(
                    nn.average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            if self.pretrain:

                def train(input_np, target_np):
                    l, _ = nn.tf_sess.run(
                        [loss, loss_gv_op],
                        feed_dict={
                            self.model.input_t: input_np,
                            self.model.target_t: target_np
                        })
                    return l
            else:

                def train(input_np, target_np):
                    l, _ = nn.tf_sess.run(
                        [loss, loss_gv_op],
                        feed_dict={
                            self.model.input_t: input_np,
                            self.model.target_t: target_np
                        })
                    return l

            self.train = train

            def view(input_np):
                return nn.tf_sess.run([pred],
                                      feed_dict={self.model.input_t: input_np})

            self.view = view

            # initializing sample generators
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_dst_generators_count = cpu_count // 2
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            if self.pretrain:
                pretrain_gen = SampleGeneratorFace(
                    self.get_pretraining_data_path(),
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.G,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                    ],
                    uniform_yaw_distribution=False,
                    generators_count=cpu_count)
                self.set_training_data_generators([pretrain_gen])
            else:
                srcdst_generator = SampleGeneratorFaceXSeg(
                    [self.training_data_src_path, self.training_data_dst_path],
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    resolution=resolution,
                    face_type=self.face_type,
                    generators_count=src_dst_generators_count,
                    data_format=nn.data_format)

                src_generator = SampleGeneratorFace(
                    self.training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': False,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'border_replicate': False,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                    ],
                    generators_count=src_generators_count,
                    raise_on_no_data=False)
                dst_generator = SampleGeneratorFace(
                    self.training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': False,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'border_replicate': False,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                    ],
                    generators_count=dst_generators_count,
                    raise_on_no_data=False)

                self.set_training_data_generators(
                    [srcdst_generator, src_generator, dst_generator])
예제 #4
0
class XSegModel(ModelBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, force_model_class_name='XSeg', **kwargs)

    #override
    def on_initialize_options(self):
        ask_override = self.ask_override()

        if not self.is_first_run() and ask_override:
            if io.input_bool(
                    f"Restart training?",
                    False,
                    help_message=
                    "Reset model weights and start training from scratch."):
                self.set_iter(0)

        default_face_type = self.options[
            'face_type'] = self.load_or_def_option('face_type', 'wf')
        default_pretrain = self.options['pretrain'] = self.load_or_def_option(
            'pretrain', False)

        if self.is_first_run():
            self.options['face_type'] = io.input_str(
                "Face type",
                default_face_type, ['h', 'mf', 'f', 'wf', 'head'],
                help_message=
                "Half / mid face / full face / whole face / head. Choose the same as your deepfake model."
            ).lower()

        if self.is_first_run() or ask_override:
            self.ask_batch_size(4, range=[2, 16])
            self.options['pretrain'] = io.input_bool("Enable pretraining mode",
                                                     default_pretrain)

        if not self.is_exporting and (
                self.options['pretrain']
                and self.get_pretraining_data_path() is None):
            raise Exception("pretraining_data_path is not defined")

        self.pretrain_just_disabled = (default_pretrain == True
                                       and self.options['pretrain'] == False)

    #override
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if self.is_exporting or (len(
            device_config.devices) != 0 and not self.is_debug()) else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256

        self.face_type = {
            'h': FaceType.HALF,
            'mf': FaceType.MID_FULL,
            'f': FaceType.FULL,
            'wf': FaceType.WHOLE_FACE,
            'head': FaceType.HEAD
        }[self.options['face_type']]

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name

        bgr_shape = nn.get4Dshape(resolution, resolution, 3)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        # Initializing model classes
        self.model = XSegNet(name='XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001,
                                                  lr_dropout=0.3,
                                                  name='opt'),
                             data_format=nn.data_format)

        self.pretrain = self.options['pretrain']
        if self.pretrain_just_disabled:
            self.set_iter(0)

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []

            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}'
                               if len(devices) != 0 else f'/CPU:0'):
                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                        gpu_input_t, pretrain=self.pretrain)
                    gpu_pred_list.append(gpu_pred_t)

                    if self.pretrain:
                        # Structural loss
                        gpu_loss = tf.reduce_mean(
                            5 * nn.dssim(gpu_target_t,
                                         gpu_pred_t,
                                         max_val=1.0,
                                         filter_size=int(resolution / 11.6)),
                            axis=[1])
                        gpu_loss += tf.reduce_mean(
                            5 * nn.dssim(gpu_target_t,
                                         gpu_pred_t,
                                         max_val=1.0,
                                         filter_size=int(resolution / 23.2)),
                            axis=[1])
                        # Pixel loss
                        gpu_loss += tf.reduce_mean(
                            10 * tf.square(gpu_target_t - gpu_pred_t),
                            axis=[1, 2, 3])
                    else:
                        gpu_loss = tf.reduce_mean(
                            tf.nn.sigmoid_cross_entropy_with_logits(
                                labels=gpu_target_t, logits=gpu_pred_logits_t),
                            axis=[1, 2, 3])

                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [
                        nn.gradients(gpu_loss, self.model.get_weights())
                    ]

            # Average losses and gradients, and create optimizer update ops
            #with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok
            with tf.device(models_opt_device):
                pred = tf.concat(gpu_pred_list, 0)
                loss = tf.concat(gpu_losses, 0)
                loss_gv_op = self.model.opt.get_update_op(
                    nn.average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            if self.pretrain:

                def train(input_np, target_np):
                    l, _ = nn.tf_sess.run(
                        [loss, loss_gv_op],
                        feed_dict={
                            self.model.input_t: input_np,
                            self.model.target_t: target_np
                        })
                    return l
            else:

                def train(input_np, target_np):
                    l, _ = nn.tf_sess.run(
                        [loss, loss_gv_op],
                        feed_dict={
                            self.model.input_t: input_np,
                            self.model.target_t: target_np
                        })
                    return l

            self.train = train

            def view(input_np):
                return nn.tf_sess.run([pred],
                                      feed_dict={self.model.input_t: input_np})

            self.view = view

            # initializing sample generators
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_dst_generators_count = cpu_count // 2
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            if self.pretrain:
                pretrain_gen = SampleGeneratorFace(
                    self.get_pretraining_data_path(),
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=True),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': True,
                            'transform': True,
                            'channel_type': SampleProcessor.ChannelType.G,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                    ],
                    uniform_yaw_distribution=False,
                    generators_count=cpu_count)
                self.set_training_data_generators([pretrain_gen])
            else:
                srcdst_generator = SampleGeneratorFaceXSeg(
                    [self.training_data_src_path, self.training_data_dst_path],
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    resolution=resolution,
                    face_type=self.face_type,
                    generators_count=src_dst_generators_count,
                    data_format=nn.data_format)

                src_generator = SampleGeneratorFace(
                    self.training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': False,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'border_replicate': False,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                    ],
                    generators_count=src_generators_count,
                    raise_on_no_data=False)
                dst_generator = SampleGeneratorFace(
                    self.training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(
                        random_flip=False),
                    output_sample_types=[
                        {
                            'sample_type':
                            SampleProcessor.SampleType.FACE_IMAGE,
                            'warp': False,
                            'transform': False,
                            'channel_type': SampleProcessor.ChannelType.BGR,
                            'border_replicate': False,
                            'face_type': self.face_type,
                            'data_format': nn.data_format,
                            'resolution': resolution
                        },
                    ],
                    generators_count=dst_generators_count,
                    raise_on_no_data=False)

                self.set_training_data_generators(
                    [srcdst_generator, src_generator, dst_generator])

    #override
    def get_model_filename_list(self):
        return self.model.model_filename_list

    #override
    def onSave(self):
        self.model.save_weights()

    #override
    def onTrainOneIter(self):
        image_np, target_np = self.generate_next_samples()[0]
        loss = self.train(image_np, target_np)

        return (('loss', np.mean(loss)), )

    #override
    def onGetPreview(self, samples, for_history=False):
        n_samples = min(4, self.get_batch_size(), 800 // self.resolution)

        if self.pretrain:
            srcdst_samples, = samples
            image_np, mask_np = srcdst_samples
        else:
            srcdst_samples, src_samples, dst_samples = samples
            image_np, mask_np = srcdst_samples

        I, M, IM, = [
            np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0,
                    1.0) for x in ([image_np, mask_np] + self.view(image_np))
        ]
        M, IM, = [np.repeat(x, (3, ), -1) for x in [M, IM]]

        green_bg = np.tile(
            np.array([0, 1, 0], dtype=np.float32)[None, None, ...],
            (self.resolution, self.resolution, 1))

        result = []
        st = []
        for i in range(n_samples):
            if self.pretrain:
                ar = I[i], IM[i]
            else:
                ar = I[i] * M[i] + 0.5 * I[i] * (1 - M[i]) + 0.5 * green_bg * (
                    1 - M[i]), IM[i], I[i] * IM[i] + 0.5 * I[i] * (
                        1 - IM[i]) + 0.5 * green_bg * (1 - IM[i])
            st.append(np.concatenate(ar, axis=1))
        result += [
            ('XSeg training faces', np.concatenate(st, axis=0)),
        ]

        if not self.pretrain and len(src_samples) != 0:
            src_np, = src_samples

            D, DM, = [
                np.clip(nn.to_data_format(x, "NHWC", self.model_data_format),
                        0.0, 1.0) for x in ([src_np] + self.view(src_np))
            ]
            DM, = [np.repeat(x, (3, ), -1) for x in [DM]]

            st = []
            for i in range(n_samples):
                ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * (
                    1 - DM[i]) + 0.5 * green_bg * (1 - DM[i])
                st.append(np.concatenate(ar, axis=1))

            result += [
                ('XSeg src faces', np.concatenate(st, axis=0)),
            ]

        if not self.pretrain and len(dst_samples) != 0:
            dst_np, = dst_samples

            D, DM, = [
                np.clip(nn.to_data_format(x, "NHWC", self.model_data_format),
                        0.0, 1.0) for x in ([dst_np] + self.view(dst_np))
            ]
            DM, = [np.repeat(x, (3, ), -1) for x in [DM]]

            st = []
            for i in range(n_samples):
                ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * (
                    1 - DM[i]) + 0.5 * green_bg * (1 - DM[i])
                st.append(np.concatenate(ar, axis=1))

            result += [
                ('XSeg dst faces', np.concatenate(st, axis=0)),
            ]

        return result

    def export_dfm(self):
        output_path = self.get_strpath_storage_for_file(f'model.onnx')
        io.log_info(f'Dumping .onnx to {output_path}')
        tf = nn.tf

        with tf.device(nn.tf_default_device_name):
            input_t = tf.placeholder(
                nn.floatx, (None, self.resolution, self.resolution, 3),
                name='in_face')
            input_t = tf.transpose(input_t, (0, 3, 1, 2))
            _, pred_t = self.model.flow(input_t)
            pred_t = tf.transpose(pred_t, (0, 2, 3, 1))

        tf.identity(pred_t, name='out_mask')

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            nn.tf_sess,
            tf.get_default_graph().as_graph_def(), ['out_mask'])

        import tf2onnx
        with tf.device("/CPU:0"):
            model_proto, _ = tf2onnx.convert._convert_common(
                output_graph_def,
                name='XSeg',
                input_names=['in_face:0'],
                output_names=['out_mask:0'],
                opset=13,
                output_path=output_path)
예제 #5
0
class XSegModel(ModelBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, force_model_class_name='XSeg', **kwargs)

    #override
    def on_initialize_options(self):
        self.set_batch_size(4)

        ask_override = self.ask_override()

        if not self.is_first_run() and ask_override:
            if io.input_bool(
                    f"Restart training?",
                    False,
                    help_message=
                    "Reset model weights and start training from scratch."):
                self.set_iter(0)

        default_face_type = self.options[
            'face_type'] = self.load_or_def_option('face_type', 'wf')

        if self.is_first_run():
            self.options['face_type'] = io.input_str(
                "Face type",
                default_face_type, ['h', 'mf', 'f', 'wf', 'head'],
                help_message=
                "Half / mid face / full face / whole face / head. Choose the same as your deepfake model."
            ).lower()

    #override
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(
            device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256

        self.face_type = {
            'h': FaceType.HALF,
            'mf': FaceType.MID_FULL,
            'f': FaceType.FULL,
            'wf': FaceType.WHOLE_FACE,
            'head': FaceType.HEAD
        }[self.options['face_type']]

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

        bgr_shape = nn.get4Dshape(resolution, resolution, 3)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        # Initializing model classes
        self.model = XSegNet(name='XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001,
                                                  lr_dropout=0.3,
                                                  name='opt'),
                             data_format=nn.data_format)

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []

            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                        gpu_input_t)
                    gpu_pred_list.append(gpu_pred_t)

                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])
                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [
                        nn.gradients(gpu_loss, self.model.get_weights())
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred = nn.concat(gpu_pred_list, 0)
                loss = tf.reduce_mean(gpu_losses)

                loss_gv_op = self.model.opt.get_update_op(
                    nn.average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            def train(input_np, target_np):
                l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                      feed_dict={
                                          self.model.input_t: input_np,
                                          self.model.target_t: target_np
                                      })
                return l

            self.train = train

            def view(input_np):
                return nn.tf_sess.run([pred],
                                      feed_dict={self.model.input_t: input_np})

            self.view = view

            # initializing sample generators
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_dst_generators_count = cpu_count // 2
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            srcdst_generator = SampleGeneratorFaceXSeg(
                [self.training_data_src_path, self.training_data_dst_path],
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                resolution=resolution,
                face_type=self.face_type,
                generators_count=src_dst_generators_count,
                data_format=nn.data_format)

            src_generator = SampleGeneratorFace(
                self.training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=src_generators_count,
                raise_on_no_data=False)
            dst_generator = SampleGeneratorFace(
                self.training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=dst_generators_count,
                raise_on_no_data=False)

            self.set_training_data_generators(
                [srcdst_generator, src_generator, dst_generator])

    #override
    def get_model_filename_list(self):
        return self.model.model_filename_list

    #override
    def onSave(self):
        self.model.save_weights()

    #override
    def onTrainOneIter(self):

        image_np, mask_np = self.generate_next_samples()[0]
        loss = self.train(image_np, mask_np)

        return (('loss', loss), )

    #override
    def onGetPreview(self, samples):
        n_samples = min(4, self.get_batch_size(), 800 // self.resolution)

        srcdst_samples, src_samples, dst_samples = samples
        image_np, mask_np = srcdst_samples

        I, M, IM, = [
            np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0,
                    1.0) for x in ([image_np, mask_np] + self.view(image_np))
        ]
        M, IM, = [np.repeat(x, (3, ), -1) for x in [M, IM]]

        green_bg = np.tile(
            np.array([0, 1, 0], dtype=np.float32)[None, None, ...],
            (self.resolution, self.resolution, 1))

        result = []
        st = []
        for i in range(n_samples):
            ar = I[i] * M[i] + 0.5 * I[i] * (1 - M[i]) + 0.5 * green_bg * (
                1 - M[i]), IM[i], I[i] * IM[i] + 0.5 * I[i] * (
                    1 - IM[i]) + 0.5 * green_bg * (1 - IM[i])
            st.append(np.concatenate(ar, axis=1))
        result += [
            ('XSeg training faces', np.concatenate(st, axis=0)),
        ]

        if len(src_samples) != 0:
            src_np, = src_samples

            D, DM, = [
                np.clip(nn.to_data_format(x, "NHWC", self.model_data_format),
                        0.0, 1.0) for x in ([src_np] + self.view(src_np))
            ]
            DM, = [np.repeat(x, (3, ), -1) for x in [DM]]

            st = []
            for i in range(n_samples):
                ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * (
                    1 - DM[i]) + 0.5 * green_bg * (1 - DM[i])
                st.append(np.concatenate(ar, axis=1))

            result += [
                ('XSeg src faces', np.concatenate(st, axis=0)),
            ]

        if len(dst_samples) != 0:
            dst_np, = dst_samples

            D, DM, = [
                np.clip(nn.to_data_format(x, "NHWC", self.model_data_format),
                        0.0, 1.0) for x in ([dst_np] + self.view(dst_np))
            ]
            DM, = [np.repeat(x, (3, ), -1) for x in [DM]]

            st = []
            for i in range(n_samples):
                ar = D[i], DM[i], D[i] * DM[i] + 0.5 * D[i] * (
                    1 - DM[i]) + 0.5 * green_bg * (1 - DM[i])
                st.append(np.concatenate(ar, axis=1))

            result += [
                ('XSeg dst faces', np.concatenate(st, axis=0)),
            ]

        return result
예제 #6
0
    def on_initialize(self):
        device_config = nn.getCurrentDeviceConfig()
        self.model_data_format = "NCHW" if len(
            device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256
        self.face_type = FaceType.WHOLE_FACE

        place_model_on_cpu = len(devices) == 0
        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

        bgr_shape = nn.get4Dshape(resolution, resolution, 3)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        # Initializing model classes
        self.model = XSegNet(name=f'XSeg',
                             resolution=resolution,
                             load_weights=not self.is_first_run(),
                             weights_file_root=self.get_model_root_path(),
                             training=True,
                             place_model_on_cpu=place_model_on_cpu,
                             optimizer=nn.RMSprop(lr=0.0001,
                                                  lr_dropout=0.3,
                                                  name='opt'),
                             data_format=nn.data_format)

        if self.is_training:
            # Adjust batch size for multiple GPU
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []

            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(
                        f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0'):

                    with tf.device(f'/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transfered to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu,
                                            (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[
                            batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.flow(
                        gpu_input_t)
                    gpu_pred_list.append(gpu_pred_t)

                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])
                    gpu_losses += [gpu_loss]

                    gpu_loss_gvs += [
                        nn.gradients(gpu_loss, self.model.get_weights())
                    ]

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred = nn.concat(gpu_pred_list, 0)
                loss = tf.reduce_mean(gpu_losses)

                loss_gv_op = self.model.opt.get_update_op(
                    nn.average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            def train(input_np, target_np):
                l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                      feed_dict={
                                          self.model.input_t: input_np,
                                          self.model.target_t: target_np
                                      })
                return l

            self.train = train

            def view(input_np):
                return nn.tf_sess.run([pred],
                                      feed_dict={self.model.input_t: input_np})

            self.view = view

            # initializing sample generators
            cpu_count = min(multiprocessing.cpu_count(), 8)
            src_dst_generators_count = cpu_count // 2
            src_generators_count = cpu_count // 2
            dst_generators_count = cpu_count // 2

            srcdst_generator = SampleGeneratorFaceXSeg(
                [self.training_data_src_path, self.training_data_dst_path],
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                resolution=resolution,
                face_type=self.face_type,
                generators_count=src_dst_generators_count,
                data_format=nn.data_format)

            src_generator = SampleGeneratorFace(
                self.training_data_src_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=src_generators_count,
                raise_on_no_data=False)
            dst_generator = SampleGeneratorFace(
                self.training_data_dst_path,
                debug=self.is_debug(),
                batch_size=self.get_batch_size(),
                sample_process_options=SampleProcessor.Options(
                    random_flip=False),
                output_sample_types=[
                    {
                        'sample_type': SampleProcessor.SampleType.FACE_IMAGE,
                        'warp': False,
                        'transform': False,
                        'channel_type': SampleProcessor.ChannelType.BGR,
                        'border_replicate': False,
                        'face_type': self.face_type,
                        'data_format': nn.data_format,
                        'resolution': resolution
                    },
                ],
                generators_count=dst_generators_count,
                raise_on_no_data=False)

            self.set_training_data_generators(
                [srcdst_generator, src_generator, dst_generator])