コード例 #1
0
    def __init__(self,
                 model_path,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 pretraining_data_path=None,
                 debug=False,
                 device_args=None,
                 ask_enable_autobackup=True,
                 ask_write_preview_history=True,
                 ask_target_iter=True,
                 ask_batch_size=True,
                 ask_sort_by_yaw=True,
                 ask_random_flip=True,
                 ask_src_scale_mod=True):
        """Set up the model: pick a device, load saved state, ask for options.

        The constructor runs these steps in order:

        1. Fill in device defaults and, when several GPUs are reported and
           none was forced, ask which one to use.
        2. Load ``data.dat`` (iteration counter, options, loss history and
           preview sample) if it exists.
        3. Prompt for the common options on the first run, or when the user
           asks to override them on a later training run.
        4. Call ``onInitializeOptions`` and ``onInitialize`` (expected to be
           provided by the concrete model class).
        5. In training mode, prepare the preview-history / autobackup
           folders, validate the sample generators and pick a preview sample.
        6. Build and log a text summary of the model, options and hardware.

        Args:
            model_path: Folder of the model; joined with ``/`` to build the
                history and autobackup folder paths, so it must be path-like.
            training_data_src_path: Source training data location.
            training_data_dst_path: Destination training data location.
                Training mode is enabled only when both paths are not None.
            pretraining_data_path: Stored on the instance as-is.
            debug: When true the effective batch size is forced to 1.
            device_args: Dict forwarded as keyword arguments to
                ``nnlib.DeviceConfig``. Missing ``'force_gpu_idx'`` and
                ``'cpu_only'`` keys are filled in with ``-1`` / ``False``
                directly in the passed dict. ``None`` means an empty dict.
            ask_enable_autobackup: Prompt for the ``autobackup`` option.
            ask_write_preview_history: Prompt for ``write_preview_history``.
            ask_target_iter: Prompt for / restore ``target_iter``.
            ask_batch_size: Prompt for ``batch_size``.
            ask_sort_by_yaw: Prompt for / restore ``sort_by_yaw``.
            ask_random_flip: Prompt for / restore ``random_flip``.
            ask_src_scale_mod: Prompt for / restore ``src_scale_mod``.

        Raises:
            ValueError: In training mode, when ``self.generator_list`` is
                None or holds an object that is not a SampleGeneratorBase.
        """
        # The signature advertises ``device_args=None``; without this guard
        # the ``.get`` calls below fail with AttributeError for that default.
        if device_args is None:
            device_args = {}

        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
        device_args['cpu_only'] = device_args.get('cpu_only', False)

        # Only ask for a GPU when none was forced and more than one is listed.
        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info("You have multi GPUs in a system: ")
                for idx, name in idxs_names_list:
                    io.log_info("[%d] : %s" % (idx, name))

                device_args['force_gpu_idx'] = io.input_int(
                    "Which GPU idx to choose? ( skip: best GPU ) : ", -1,
                    [x[0] for x in idxs_names_list])
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=True,
                                                **self.device_args)

        io.log_info("Loading model...")

        self.model_path = model_path
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
        self.dst_yaw_images_paths = None
        self.src_data_generator = None
        self.dst_data_generator = None
        self.debug = debug
        self.is_training_mode = (training_data_src_path is not None
                                 and training_data_dst_path is not None)

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        model_data = {}
        if self.model_data_path.exists():
            # NOTE(review): pickle.loads executes arbitrary code from the
            # file; only load model folders that come from a trusted source.
            model_data = pickle.loads(self.model_data_path.read_bytes())
            # The counter is read from 'iter' or 'epoch', whichever is larger
            # ('epoch' is presumably a legacy key name - confirm).
            self.iter = max(model_data.get('iter', 0),
                            model_data.get('epoch', 0))
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)

        # Offer to re-enter the options only when resuming a training run.
        # The prompt is built from the real timeout so that it is also
        # correct on Colab, where the wait is 5 seconds instead of 2.
        ask_override = False
        if self.is_training_mode and self.iter != 0:
            override_timeout = 5 if io.is_colab() else 2
            ask_override = io.input_in_time(
                "Press enter in %d seconds to override model settings." %
                override_timeout, override_timeout)

        yn_str = {True: 'y', False: 'n'}

        if self.iter == 0:
            io.log_info(
                "\nModel first run. Enter model options as default for each run."
            )

        if ask_enable_autobackup and (self.iter == 0 or ask_override):
            default_autobackup = False if self.iter == 0 else self.options.get(
                'autobackup', False)
            self.options['autobackup'] = io.input_bool(
                "Enable autobackup? (y/n ?:help skip:%s) : " %
                (yn_str[default_autobackup]),
                default_autobackup,
                help_message=
                "Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01"
            )
        else:
            self.options['autobackup'] = self.options.get('autobackup', False)

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get(
                'write_preview_history', False)
            self.options['write_preview_history'] = io.input_bool(
                "Write preview history? (y/n ?:help skip:%s) : " %
                (yn_str[default_write_preview_history]),
                default_write_preview_history,
                help_message=
                "Preview history will be writed to <ModelName>_history folder."
            )
        else:
            self.options['write_preview_history'] = self.options.get(
                'write_preview_history', False)

        # Choosing the preview image is only offered while options are being
        # entered and preview history is on; the wording depends on whether
        # a window can be shown or the session runs on Colab.
        if (self.iter == 0 or ask_override) and self.options[
                'write_preview_history'] and io.is_support_windows():
            choose_preview_history = io.input_bool(
                "Choose image for the preview history? (y/n skip:%s) : " %
                (yn_str[False]), False)
        elif (self.iter == 0 or ask_override
              ) and self.options['write_preview_history'] and io.is_colab():
            choose_preview_history = io.input_bool(
                "Randomly choose new image for preview history? (y/n ?:help skip:%s) : "
                % (yn_str[False]),
                False,
                help_message=
                "Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person"
            )
        else:
            choose_preview_history = False

        if ask_target_iter:
            if (self.iter == 0 or ask_override):
                self.options['target_iter'] = max(
                    0,
                    io.input_int(
                        "Target iteration (skip:unlimited/default) : ", 0))
            else:
                # NOTE(review): a 'target_iter' previously stored in
                # self.options is not consulted here, only the top-level
                # model_data key and 'target_epoch'. Confirm against the
                # save code whether the target is meant to survive a resume.
                self.options['target_iter'] = max(
                    model_data.get('target_iter', 0),
                    self.options.get('target_epoch', 0))
                if 'target_epoch' in self.options:
                    self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get(
                'batch_size', 0)
            self.options['batch_size'] = max(
                0,
                io.input_int(
                    "Batch_size (?:help skip:%d) : " % (default_batch_size),
                    default_batch_size,
                    help_message=
                    "Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."
                ))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        if ask_sort_by_yaw:
            if (self.iter == 0 or ask_override):
                default_sort_by_yaw = self.options.get('sort_by_yaw', False)
                self.options['sort_by_yaw'] = io.input_bool(
                    "Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : "
                    % (yn_str[default_sort_by_yaw]),
                    default_sort_by_yaw,
                    help_message=
                    "NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw."
                )
            else:
                self.options['sort_by_yaw'] = self.options.get(
                    'sort_by_yaw', False)

        # random_flip and src_scale_mod are asked on the first run only; an
        # override on a later run keeps the stored values.
        if ask_random_flip:
            if (self.iter == 0):
                self.options['random_flip'] = io.input_bool(
                    "Flip faces randomly? (y/n ?:help skip:y) : ",
                    True,
                    help_message=
                    "Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset."
                )
            else:
                self.options['random_flip'] = self.options.get(
                    'random_flip', True)

        if ask_src_scale_mod:
            if (self.iter == 0):
                self.options['src_scale_mod'] = np.clip(
                    io.input_int(
                        "Src face scale modifier % ( -30...30, ?:help skip:0) : ",
                        0,
                        help_message=
                        "If src face shape is wider than dst, try to decrease this value to get a better result."
                    ), -30, 30)
            else:
                self.options['src_scale_mod'] = self.options.get(
                    'src_scale_mod', 0)

        # Mirror the options onto attributes. Options that hold their
        # "off" value are removed from the dict, so they are neither listed
        # in the summary below nor kept in self.options.
        self.autobackup = self.options.get('autobackup', False)
        if not self.autobackup and 'autobackup' in self.options:
            self.options.pop('autobackup')

        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
            self.options.pop('write_preview_history')

        self.target_iter = self.options.get('target_iter', 0)
        if self.target_iter == 0 and 'target_iter' in self.options:
            self.options.pop('target_iter')

        self.batch_size = self.options.get('batch_size', 0)
        self.sort_by_yaw = self.options.get('sort_by_yaw', False)
        self.random_flip = self.options.get('random_flip', True)

        self.src_scale_mod = self.options.get('src_scale_mod', 0)
        if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
            self.options.pop('src_scale_mod')

        self.onInitializeOptions(self.iter == 0, ask_override)

        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        self.onInitialize()

        # Store whatever batch size is set after onInitialize() ran.
        self.options['batch_size'] = self.batch_size

        # The stored option keeps the value above; only the effective batch
        # size is forced to 1 for debug runs or when it is still 0.
        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            # A forced GPU index becomes a prefix of the folder names.
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / (
                    '%s_history' % (self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%s_autobackups' % (self.get_model_name()))
            else:
                self.preview_history_path = self.model_path / (
                    '%d_%s_history' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%d_%s_autobackups' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))

            if self.autobackup:
                self.autobackup_current_hour = time.localtime().tm_hour

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    # A brand-new model starts with an empty history folder.
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            # generator_list is not assigned in this method; presumably
            # onInitialize() sets it via set_training_data_generators().
            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for generator in self.generator_list:
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            if self.sample_for_preview is None or choose_preview_history:
                if choose_preview_history and io.is_support_windows():
                    # Interactive picker: 'p' shows another sample, Enter
                    # confirms the one currently displayed.
                    wnd_name = "[p] - next. [enter] - confirm."
                    io.named_window(wnd_name)
                    io.capture_keys(wnd_name)
                    choosed = False
                    while not choosed:
                        self.sample_for_preview = self.generate_next_sample()
                        preview = self.get_static_preview()
                        io.show_image(wnd_name,
                                      (preview * 255).astype(np.uint8))

                        while True:
                            key_events = io.get_key_events(wnd_name)
                            # Only the most recent key event is examined.
                            key, _chr_key, _ctrl, _alt, _shift = key_events[
                                -1] if len(key_events) > 0 else (0, 0, False,
                                                                 False, False)
                            if key == ord('\n') or key == ord('\r'):
                                choosed = True
                                break
                            elif key == ord('p'):
                                break

                            try:
                                io.process_messages(0.1)
                            except KeyboardInterrupt:
                                # Ctrl+C accepts the sample on screen. The
                                # break is required, otherwise this inner
                                # loop keeps polling for keys.
                                choosed = True
                                break

                    io.destroy_window(wnd_name)
                else:
                    self.sample_for_preview = self.generate_next_sample()
                self.last_sample = self.sample_for_preview

        ###Generate text summary of model hyperparameters
        #Find the longest key name and value string. Used as column widths.
        width_name = max(
            [len(k) for k in self.options.keys()] + [17]
        ) + 1  # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
        width_value = max([len(str(x)) for x in self.options.values()] +
                          [len(str(self.iter)),
                           len(self.get_model_name())]
                          ) + 1  # Single space buffer to right edge
        if not self.device_config.cpu_only:  #Check length of GPU names
            width_value = max([
                len(nnlib.device.getDeviceName(idx)) + 1
                for idx in self.device_config.gpu_idxs
            ] + [width_value])
        width_total = width_name + width_value + 2  #Plus 2 for ": "

        model_summary_text = []
        model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='
                               ]  # Model/status summary
        model_summary_text += [f'=={" "*width_total}==']
        model_summary_text += [
            f'=={"Model name": >{width_name}}: {self.get_model_name(): <{width_value}}=='
        ]  # Name
        model_summary_text += [f'=={" "*width_total}==']
        model_summary_text += [
            f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='
        ]  # Iter
        model_summary_text += [f'=={" "*width_total}==']

        model_summary_text += [f'=={" Model Options ":-^{width_total}}=='
                               ]  # Model options
        model_summary_text += [f'=={" "*width_total}==']
        for key in self.options.keys():
            model_summary_text += [
                f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='
            ]  # self.options key/value pairs
        model_summary_text += [f'=={" "*width_total}==']

        model_summary_text += [f'=={" Running On ":-^{width_total}}=='
                               ]  # Training hardware info
        model_summary_text += [f'=={" "*width_total}==']
        if self.device_config.multi_gpu:
            model_summary_text += [
                f'=={"Using multi_gpu": >{width_name}}: {"True": <{width_value}}=='
            ]  # multi_gpu
            model_summary_text += [f'=={" "*width_total}==']
        if self.device_config.cpu_only:
            model_summary_text += [
                f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='
            ]  # cpu_only
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += [
                    f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='
                ]  # GPU hardware device index
                model_summary_text += [
                    f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='
                ]  # GPU name
                vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB'  # GPU VRAM - Formatted as #.## (or ##.##)
                model_summary_text += [
                    f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}=='
                ]
        model_summary_text += [f'=={" "*width_total}==']
        model_summary_text += [f'=={"="*width_total}==']

        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[
                0] <= 2:  # Low VRAM warning
            model_summary_text += ["/!\\"]
            model_summary_text += ["/!\\ WARNING:"]
            model_summary_text += [
                "/!\\ You are using a GPU with 2GB or less VRAM. This may significantly reduce the quality of your result!"
            ]
            model_summary_text += [
                "/!\\ If training does not start, close all programs and try again."
            ]
            model_summary_text += [
                "/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."
            ]
            model_summary_text += ["/!\\"]

        model_summary_text = "\n".join(model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)
コード例 #2
0
ファイル: ModelBase.py プロジェクト: rohithreddy/DeepFaceLab
    def __init__(self,
                 model_path,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 debug=False,
                 device_args=None,
                 ask_write_preview_history=True,
                 ask_target_iter=True,
                 ask_batch_size=True,
                 ask_sort_by_yaw=True,
                 ask_random_flip=True,
                 ask_src_scale_mod=True):
        """Set up the model: pick a device, load saved state, ask for options.

        The constructor runs these steps in order:

        1. Fill in device defaults and, when several GPUs are reported and
           none was forced, ask which one to use.
        2. Load ``data.dat`` (iteration counter, options, loss history and
           preview sample) if it exists.
        3. Prompt for the common options on the first run, or when the user
           asks to override them on a later training run.
        4. Call ``onInitializeOptions`` and ``onInitialize`` (expected to be
           provided by the concrete model class).
        5. In training mode, prepare the preview-history folder, validate
           the sample generators and make sure a preview sample exists.
        6. Build and log a text summary of the model, options and hardware.

        Args:
            model_path: Folder of the model; joined with ``/`` to build the
                history folder path, so it must be path-like.
            training_data_src_path: Source training data location.
            training_data_dst_path: Destination training data location.
                Training mode is enabled only when both paths are not None.
            debug: When true the effective batch size is forced to 1.
            device_args: Dict forwarded as keyword arguments to
                ``nnlib.DeviceConfig``. Missing ``'force_gpu_idx'`` and
                ``'cpu_only'`` keys are filled in with ``-1`` / ``False``
                directly in the passed dict. ``None`` means an empty dict.
            ask_write_preview_history: Prompt for ``write_preview_history``.
            ask_target_iter: Prompt for / restore ``target_iter``.
            ask_batch_size: Prompt for ``batch_size``.
            ask_sort_by_yaw: Prompt for / restore ``sort_by_yaw``.
            ask_random_flip: Prompt for / restore ``random_flip``.
            ask_src_scale_mod: Prompt for / restore ``src_scale_mod``.

        Raises:
            ValueError: In training mode, when ``self.generator_list`` is
                None or holds an object that is not a SampleGeneratorBase.
        """
        # The signature advertises ``device_args=None``; without this guard
        # the ``.get`` calls below fail with AttributeError for that default.
        if device_args is None:
            device_args = {}

        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
        device_args['cpu_only'] = device_args.get('cpu_only', False)

        # Only ask for a GPU when none was forced and more than one is listed.
        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info("You have multi GPUs in a system: ")
                for idx, name in idxs_names_list:
                    io.log_info("[%d] : %s" % (idx, name))

                device_args['force_gpu_idx'] = io.input_int(
                    "Which GPU idx to choose? ( skip: best GPU ) : ", -1,
                    [x[0] for x in idxs_names_list])
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=False,
                                                **self.device_args)

        io.log_info("Loading model...")

        self.model_path = model_path
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
        self.dst_yaw_images_paths = None
        self.src_data_generator = None
        self.dst_data_generator = None
        self.debug = debug
        self.is_training_mode = (training_data_src_path is not None
                                 and training_data_dst_path is not None)

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        model_data = {}
        if self.model_data_path.exists():
            # NOTE(review): pickle.loads executes arbitrary code from the
            # file; only load model folders that come from a trusted source.
            model_data = pickle.loads(self.model_data_path.read_bytes())
            # The counter is read from 'iter' or 'epoch', whichever is larger
            # ('epoch' is presumably a legacy key name - confirm).
            self.iter = max(model_data.get('iter', 0),
                            model_data.get('epoch', 0))
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)

        # Offer to re-enter the options only when resuming a training run.
        ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time(
            "Press enter in 2 seconds to override model settings.", 2)

        yn_str = {True: 'y', False: 'n'}

        if self.iter == 0:
            io.log_info(
                "\nModel first run. Enter model options as default for each run."
            )

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get(
                'write_preview_history', False)
            self.options['write_preview_history'] = io.input_bool(
                "Write preview history? (y/n ?:help skip:%s) : " %
                (yn_str[default_write_preview_history]),
                default_write_preview_history,
                help_message=
                "Preview history will be writed to <ModelName>_history folder."
            )
        else:
            self.options['write_preview_history'] = self.options.get(
                'write_preview_history', False)

        if ask_target_iter:
            if (self.iter == 0 or ask_override):
                self.options['target_iter'] = max(
                    0,
                    io.input_int(
                        "Target iteration (skip:unlimited/default) : ", 0))
            else:
                # NOTE(review): a 'target_iter' previously stored in
                # self.options is not consulted here, only the top-level
                # model_data key and 'target_epoch'. Confirm against the
                # save code whether the target is meant to survive a resume.
                self.options['target_iter'] = max(
                    model_data.get('target_iter', 0),
                    self.options.get('target_epoch', 0))
                if 'target_epoch' in self.options:
                    self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get(
                'batch_size', 0)
            self.options['batch_size'] = max(
                0,
                io.input_int(
                    "Batch_size (?:help skip:%d) : " % (default_batch_size),
                    default_batch_size,
                    help_message=
                    "Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."
                ))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        # sort_by_yaw, random_flip and src_scale_mod are asked on the first
        # run only; an override on a later run keeps the stored values.
        if ask_sort_by_yaw:
            if (self.iter == 0):
                self.options['sort_by_yaw'] = io.input_bool(
                    "Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ",
                    False,
                    help_message=
                    "NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw."
                )
            else:
                self.options['sort_by_yaw'] = self.options.get(
                    'sort_by_yaw', False)

        if ask_random_flip:
            if (self.iter == 0):
                self.options['random_flip'] = io.input_bool(
                    "Flip faces randomly? (y/n ?:help skip:y) : ",
                    True,
                    help_message=
                    "Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset."
                )
            else:
                self.options['random_flip'] = self.options.get(
                    'random_flip', True)

        if ask_src_scale_mod:
            if (self.iter == 0):
                self.options['src_scale_mod'] = np.clip(
                    io.input_int(
                        "Src face scale modifier % ( -30...30, ?:help skip:0) : ",
                        0,
                        help_message=
                        "If src face shape is wider than dst, try to decrease this value to get a better result."
                    ), -30, 30)
            else:
                self.options['src_scale_mod'] = self.options.get(
                    'src_scale_mod', 0)

        # Mirror the options onto attributes. Options that hold their
        # "off" value are removed from the dict, so they are neither listed
        # in the summary below nor kept in self.options.
        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
            self.options.pop('write_preview_history')

        self.target_iter = self.options.get('target_iter', 0)
        if self.target_iter == 0 and 'target_iter' in self.options:
            self.options.pop('target_iter')

        self.batch_size = self.options.get('batch_size', 0)
        self.sort_by_yaw = self.options.get('sort_by_yaw', False)
        self.random_flip = self.options.get('random_flip', True)

        self.src_scale_mod = self.options.get('src_scale_mod', 0)
        if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
            self.options.pop('src_scale_mod')

        self.onInitializeOptions(self.iter == 0, ask_override)

        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        self.onInitialize()

        # Store whatever batch size is set after onInitialize() ran.
        self.options['batch_size'] = self.batch_size

        # The stored option keeps the value above; only the effective batch
        # size is forced to 1 for debug runs or when it is still 0.
        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            # A forced GPU index becomes a prefix of the folder name.
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / (
                    '%s_history' % (self.get_model_name()))
            else:
                self.preview_history_path = self.model_path / (
                    '%d_%s_history' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    # A brand-new model starts with an empty history folder.
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            # generator_list is not assigned in this method; presumably
            # onInitialize() sets it via set_training_data_generators().
            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for generator in self.generator_list:
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            if (self.sample_for_preview is None) or (self.iter == 0):
                self.sample_for_preview = self.generate_next_sample()

        # Plain-text summary of the model state, options and hardware.
        model_summary_text = []

        model_summary_text += ["===== Model summary ====="]
        model_summary_text += ["== Model name: " + self.get_model_name()]
        model_summary_text += ["=="]
        model_summary_text += ["== Current iteration: " + str(self.iter)]
        model_summary_text += ["=="]
        model_summary_text += ["== Model options:"]
        for key in self.options.keys():
            model_summary_text += ["== |== %s : %s" % (key, self.options[key])]

        if self.device_config.multi_gpu:
            model_summary_text += ["== |== multi_gpu : True "]

        model_summary_text += ["== Running on:"]
        if self.device_config.cpu_only:
            model_summary_text += ["== |== [CPU]"]
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += [
                    "== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))
                ]

        # The warning is shown only when the first GPU reports exactly 2.
        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[
                0] == 2:
            model_summary_text += ["=="]
            model_summary_text += [
                "== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."
            ]
            model_summary_text += [
                "== If training does not start, close all programs and try again."
            ]
            model_summary_text += [
                "== Also you can disable Windows Aero Desktop to get extra free VRAM."
            ]
            model_summary_text += ["=="]

        model_summary_text += ["========================="]
        model_summary_text = "\r\n".join(model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)
コード例 #3
0
    def __init__(self,
                 model_path,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 pretraining_data_path=None,
                 debug=False,
                 device_args=None,
                 ask_enable_autobackup=True,
                 ask_write_preview_history=True,
                 ask_target_iter=True,
                 ask_batch_size=True,
                 ask_sort_by_yaw=True,
                 ask_random_flip=True,
                 ask_src_scale_mod=True):

        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
        device_args['cpu_only'] = device_args.get('cpu_only', False)

        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info("You have multi GPUs in a system: ")
                for idx, name in idxs_names_list:
                    io.log_info("[%d] : %s" % (idx, name))

                device_args['force_gpu_idx'] = io.input_int(
                    "Which GPU idx to choose? ( skip: best GPU ) : ", -1,
                    [x[0] for x in idxs_names_list])
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=False,
                                                **self.device_args)

        io.log_info("加载模型...")

        self.model_path = model_path
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
        self.dst_yaw_images_paths = None
        self.src_data_generator = None
        self.dst_data_generator = None
        self.debug = debug
        self.is_training_mode = (training_data_src_path is not None
                                 and training_data_dst_path is not None)

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        model_data = {}
        if self.model_data_path.exists():
            model_data = pickle.loads(self.model_data_path.read_bytes())
            self.iter = max(model_data.get('iter', 0),
                            model_data.get('epoch', 0))
            if 'epoch' in self.options:
                self.options.pop('epoch')
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)

        ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time(
            "\n2秒内按回车键[Enter]可以重新配置部分参数。\n\n", 5 if io.is_colab() else 2)

        yn_str = {True: 'y', False: 'n'}

        if self.iter == 0:
            io.log_info("\n第一次启动模型. 请输入模型选项,当再次启动时会加载当前配置.\n")

        if ask_enable_autobackup and (self.iter == 0 or ask_override):
            default_autobackup = False if self.iter == 0 else self.options.get(
                'autobackup', False)
            self.options['autobackup'] = io.input_bool(
                "启动备份? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]),
                default_autobackup,
                help_message=
                "自动备份模型文件,过去15小时每小时备份一次。 位于model / <> _ autobackups /")
        else:
            self.options['autobackup'] = self.options.get('autobackup', False)

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get(
                'write_preview_history', False)
            self.options['write_preview_history'] = io.input_bool(
                "保存历史预览图[write_preview_history]? (y/n ?:help skip:%s) : " %
                (yn_str[default_write_preview_history]),
                default_write_preview_history,
                help_message="预览图保存在<模型名称>_history文件夹。")
        else:
            self.options['write_preview_history'] = self.options.get(
                'write_preview_history', False)

        if (self.iter == 0 or ask_override) and self.options[
                'write_preview_history'] and io.is_support_windows():
            choose_preview_history = io.input_bool(
                "选择预览图图片[write_preview_history]? (y/n skip:%s) : " %
                (yn_str[False]), False)
        else:
            choose_preview_history = False

        if ask_target_iter:
            if (self.iter == 0 or ask_override):
                self.options['target_iter'] = max(
                    0,
                    io.input_int(
                        "目标迭代次数[Target iteration] (skip:unlimited/default) : ",
                        0))
            else:
                self.options['target_iter'] = max(
                    model_data.get('target_iter', 0),
                    self.options.get('target_epoch', 0))
                if 'target_epoch' in self.options:
                    self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get(
                'batch_size', 0)
            self.options['batch_size'] = max(
                0,
                io.input_int(
                    "批处理大小[Batch_size] (?:help skip:%d) : " %
                    (default_batch_size),
                    default_batch_size,
                    help_message=
                    "较大的批量大小更适合神经网络[NN]的泛化,但它可能导致内存不足[OOM]的错误。根据你显卡配置合理设置改选项,默认为4,推荐16."
                ))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        if ask_sort_by_yaw:
            if (self.iter == 0 or ask_override):
                default_sort_by_yaw = self.options.get('sort_by_yaw', False)
                self.options['sort_by_yaw'] = io.input_bool(
                    "根据侧脸排序[Feed faces to network sorted by yaw]? (y/n ?:help skip:%s) : "
                    % (yn_str[default_sort_by_yaw]),
                    default_sort_by_yaw,
                    help_message=
                    "神经网络[NN]不会学习与dst面部方向不匹配的src面部方向。 如果dst脸部有覆盖下颚的头发,请不要启用.")
            else:
                self.options['sort_by_yaw'] = self.options.get(
                    'sort_by_yaw', False)

        if ask_random_flip:
            if (self.iter == 0):
                self.options['random_flip'] = io.input_bool(
                    "随机反转[Flip faces randomly]? (y/n ?:help skip:y) : ",
                    True,
                    help_message=
                    "如果没有此选项,预测的脸部看起来会更自然,但源[src]的脸部集合[faceset]应覆盖所有面部方向,去陪陪目标[dst]的脸部集合[faceset]。"
                )
            else:
                self.options['random_flip'] = self.options.get(
                    'random_flip', True)

        if ask_src_scale_mod:
            if (self.iter == 0):
                self.options['src_scale_mod'] = np.clip(
                    io.input_int(
                        "源脸缩放[Src face scale modifier] % ( -30...30, ?:help skip:0) : ",
                        0,
                        help_message="如果src面部形状比dst宽,请尝试减小此值以获得更好的结果。"), -30,
                    30)
            else:
                self.options['src_scale_mod'] = self.options.get(
                    'src_scale_mod', 0)

        self.autobackup = self.options.get('autobackup', False)
        if not self.autobackup and 'autobackup' in self.options:
            self.options.pop('autobackup')

        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
            self.options.pop('write_preview_history')

        self.target_iter = self.options.get('target_iter', 0)
        if self.target_iter == 0 and 'target_iter' in self.options:
            self.options.pop('target_iter')

        self.batch_size = self.options.get('batch_size', 0)
        self.sort_by_yaw = self.options.get('sort_by_yaw', False)
        self.random_flip = self.options.get('random_flip', True)

        self.src_scale_mod = self.options.get('src_scale_mod', 0)
        if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
            self.options.pop('src_scale_mod')

        self.onInitializeOptions(self.iter == 0, ask_override)

        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        self.onInitialize()

        self.options['batch_size'] = self.batch_size

        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / (
                    '%s_history' % (self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%s_autobackups' % (self.get_model_name()))
            else:
                self.preview_history_path = self.model_path / (
                    '%d_%s_history' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%d_%s_autobackups' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))

            if self.autobackup:
                self.autobackup_current_hour = time.localtime().tm_hour

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for i, generator in enumerate(self.generator_list):
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            if self.sample_for_preview is None or choose_preview_history:
                if choose_preview_history and io.is_support_windows():
                    wnd_name = "[p] - next. [enter] - confirm."
                    io.named_window(wnd_name)
                    io.capture_keys(wnd_name)
                    choosed = False
                    while not choosed:
                        self.sample_for_preview = self.generate_next_sample()
                        preview = self.get_static_preview()
                        io.show_image(wnd_name,
                                      (preview * 255).astype(np.uint8))

                        while True:
                            key_events = io.get_key_events(wnd_name)
                            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                                -1] if len(key_events) > 0 else (0, 0, False,
                                                                 False, False)
                            if key == ord('\n') or key == ord('\r'):
                                choosed = True
                                break
                            elif key == ord('p'):
                                break

                            try:
                                io.process_messages(0.1)
                            except KeyboardInterrupt:
                                choosed = True

                    io.destroy_window(wnd_name)
                else:
                    self.sample_for_preview = self.generate_next_sample()
                self.last_sample = self.sample_for_preview
        model_summary_text = []

        model_summary_text += ["\n===== 模型信息 =====\n"]
        model_summary_text += ["== 模型名称: " + self.get_model_name()]
        model_summary_text += ["=="]
        model_summary_text += ["== 当前迭代: " + str(self.iter)]
        model_summary_text += ["=="]
        model_summary_text += ["== 模型配置信息:"]
        for key in self.options.keys():
            model_summary_text += ["== |== %s : %s" % (key, self.options[key])]

        if self.device_config.multi_gpu:
            model_summary_text += ["== |== multi_gpu : True "]

        model_summary_text += ["== Running on:"]
        if self.device_config.cpu_only:
            model_summary_text += ["== |== [CPU]"]
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += [
                    "== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))
                ]

        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[
                0] == 2:
            model_summary_text += ["=="]
            model_summary_text += [
                "== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."
            ]
            model_summary_text += [
                "== If training does not start, close all programs and try again."
            ]
            model_summary_text += [
                "== Also you can disable Windows Aero Desktop to get extra free VRAM."
            ]
            model_summary_text += ["=="]

        model_summary_text += ["========================="]
        model_summary_text = "\r\n".join(model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)