コード例 #1
0
    def ask_write_preview_history(self, default_value=False):
        """Ask whether to write preview history and how its preview sample is picked.

        The answer is stored in ``self.options['write_preview_history']``. The
        default offered to the user comes from
        ``self.load_or_def_option('write_preview_history', default_value)``;
        presumably the previously stored option, else *default_value*.

        When history writing is enabled, a platform-specific follow-up question
        sets ``self.choose_preview_history``:

        * on Windows (``io.is_support_windows()``): whether the user picks the
          preview image by hand;
        * on Colab (``io.is_colab()``): whether a new random preview image is
          picked.

        In any other case ``self.choose_preview_history`` is left unchanged.

        Args:
            default_value: Fallback default for the "Write preview history"
                question.
        """
        default_write_preview_history = self.load_or_def_option('write_preview_history', default_value)
        self.options['write_preview_history'] = io.input_bool(
            "Write preview history",
            default_write_preview_history,
            help_message="Preview history will be written to <ModelName>_history folder.")

        if self.options['write_preview_history']:
            if io.is_support_windows():
                self.choose_preview_history = io.input_bool("Choose image for the preview history", False)
            elif io.is_colab():
                self.choose_preview_history = io.input_bool("Randomly choose new image for preview history", False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person")
コード例 #2
0
    def update_sample_for_preview(self,
                                  choose_preview_history=False,
                                  force_new=False):
        """Refresh ``self.sample_for_preview`` and mirror it into ``self.last_sample``.

        A new sample is drawn with ``self.generate_next_samples()`` when there is
        no current sample, or when *choose_preview_history* or *force_new* is set.

        If *choose_preview_history* is set and ``io.is_support_windows()`` is
        true, the user chooses the sample in a window:

        * ``p``     -- draw the next sample;
        * ``space`` -- show the next entry of ``self.get_static_previews()`` for
          the sample currently displayed (the sample itself is kept);
        * ``enter`` -- accept the displayed sample.

        A ``KeyboardInterrupt`` while waiting for a key accepts the displayed
        sample. The window is destroyed even if drawing or rendering raises.

        Finally, if ``self.get_static_previews()`` raises for the resulting
        sample, one more sample is drawn as a fallback.

        Args:
            choose_preview_history: Let the user pick the sample interactively
                (only honoured when ``io.is_support_windows()`` is true).
            force_new: Draw a new sample even if one already exists.
        """
        if self.sample_for_preview is None or choose_preview_history or force_new:
            if choose_preview_history and io.is_support_windows():
                wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm."
                io.log_info(
                    f"Choose image for the preview history. {wnd_name}")
                io.named_window(wnd_name)
                io.capture_keys(wnd_name)
                try:
                    choosed = False
                    preview_id_counter = 0
                    while not choosed:
                        self.sample_for_preview = self.generate_next_samples()
                        previews = self.get_static_previews()
                        redraw = True

                        while True:
                            if redraw:
                                # previews[i][1] is scaled by 255 for display;
                                # presumably a float image in [0, 1] -- TODO confirm.
                                image = previews[preview_id_counter % len(previews)][1]
                                io.show_image(wnd_name,
                                              (image * 255).astype(np.uint8))
                                redraw = False

                            key_events = io.get_key_events(wnd_name)
                            # Only the most recent event matters; its first field is the key code.
                            key = key_events[-1][0] if len(key_events) > 0 else 0
                            if key == ord('\n') or key == ord('\r'):
                                choosed = True
                                break
                            elif key == ord(' '):
                                # Switch preview type for the SAME sample instead of
                                # falling back to the outer loop (which draws a new one).
                                preview_id_counter += 1
                                redraw = True
                                continue
                            elif key == ord('p'):
                                break

                            try:
                                io.process_messages(0.1)
                            except KeyboardInterrupt:
                                # Accept the shown sample; without the break the
                                # inner loop kept polling after Ctrl+C.
                                choosed = True
                                break
                finally:
                    io.destroy_window(wnd_name)
            else:
                self.sample_for_preview = self.generate_next_samples()

        try:
            self.get_static_previews()
        except Exception:
            # Best effort: the current sample cannot be rendered, so draw another one.
            self.sample_for_preview = self.generate_next_samples()

        self.last_sample = self.sample_for_preview
コード例 #3
0
ファイル: ModelBase.py プロジェクト: vkataev/DeepFaceLab
    def __init__(self,
                 is_training=False,
                 saved_models_path=None,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 pretraining_data_path=None,
                 pretrained_model_path=None,
                 no_preview=False,
                 force_model_name=None,
                 force_gpu_idxs=None,
                 cpu_only=False,
                 debug=False,
                 **kwargs):
        """Select or create a model, load its saved state and initialise it.

        Steps, in order:

        1. Determine the model name: *force_model_name* if given, otherwise an
           interactive prompt listing the models of this class found in
           *saved_models_path* (most recently modified first), which also
           offers renaming (``r``) and deleting (``d``) a saved model. The model
           class name is then appended to the chosen name.
        2. If the model's ``data.dat`` exists, load ``iter`` from it and, when
           ``iter`` is non-zero, ``options``, ``loss_history``,
           ``sample_for_preview`` and ``choosed_gpu_indexes``.
        3. Initialise ``nn`` with a CPU device config when *cpu_only*,
           otherwise a GPU config (*force_gpu_idxs* or an interactive choice).
        4. Load the per-class default options, call
           ``self.on_initialize_options()`` (on the first run, save the
           resulting options as the new defaults), read a few options into
           attributes, then call ``self.on_initialize()``.
        5. When *is_training*: prepare the history / autobackup folders,
           validate ``self.generator_list`` and pick the preview sample.

        Args:
            is_training: Enable the training-only setup of step 5.
            saved_models_path: Directory holding the saved model files. It is
                combined with ``/`` below, so it is expected to be a
                ``pathlib.Path``.
            training_data_src_path, training_data_dst_path,
            pretraining_data_path, pretrained_model_path, no_preview, debug:
                Stored unchanged as attributes of the same name.
            force_model_name: Use this model name instead of prompting.
            force_gpu_idxs: GPU indexes to use; when falsy the user is asked.
            cpu_only: Use the CPU device config instead of GPUs.
            **kwargs: Not used in this method.

        Raises:
            ValueError: When training and ``self.generator_list`` is ``None`` or
                contains an object that is not a ``SampleGeneratorBase``.
        """
        self.is_training = is_training
        self.saved_models_path = saved_models_path
        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path
        self.pretrained_model_path = pretrained_model_path
        self.no_preview = no_preview
        self.debug = debug

        # The class name is the suffix after the last "_" of the name of the
        # directory that contains the concrete model's module.
        self.model_class_name = model_class_name = Path(
            inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]

        if force_model_name is not None:
            self.model_name = force_model_name
        else:
            # Loops only so that a rename/delete command re-lists the models.
            while True:
                # Gather all saved models of this class; the model name is the
                # part of the data file name before the first "_".
                saved_models_names = []
                for filepath in pathex.get_file_paths(saved_models_path):
                    filepath_name = filepath.name
                    if filepath_name.endswith(f'{model_class_name}_data.dat'):
                        saved_models_names += [(filepath_name.split('_')[0],
                                                os.path.getmtime(filepath))]

                # Sort by modification time, most recent first.
                saved_models_names = sorted(saved_models_names,
                                            key=operator.itemgetter(1),
                                            reverse=True)
                saved_models_names = [x[0] for x in saved_models_names]

                if len(saved_models_names) != 0:
                    io.log_info(
                        "Choose one of saved models, or enter a name to create a new model."
                    )
                    io.log_info("[r] : rename")
                    io.log_info("[d] : delete")
                    io.log_info("")
                    for i, model_name in enumerate(saved_models_names):
                        s = f"[{i}] : {model_name} "
                        if i == 0:
                            s += "- latest"
                        io.log_info(s)

                    inp = io.input_str(f"", "0", show_default_value=False)
                    model_idx = -1
                    try:
                        # Out-of-range numbers are clamped to a valid index.
                        model_idx = np.clip(int(inp), 0,
                                            len(saved_models_names) - 1)
                    except ValueError:
                        # Not a number: a command ("r"/"d") or a new model name.
                        pass

                    if model_idx == -1:
                        if len(inp) == 1:
                            is_rename = inp[0] == 'r'
                            is_delete = inp[0] == 'd'

                            if is_rename or is_delete:
                                if is_rename:
                                    name = io.input_str(
                                        f"Enter the name of the model you want to rename"
                                    )
                                else:
                                    name = io.input_str(
                                        f"Enter the name of the model you want to delete"
                                    )

                                if name in saved_models_names:
                                    if is_rename:
                                        new_model_name = io.input_str(
                                            f"Enter new name of the model")

                                    for filepath in pathex.get_paths(
                                            saved_models_path):
                                        # Entries of model "<name>" are named
                                        # "<name>_<rest>". Entries without "_"
                                        # cannot belong to a model; unpacking
                                        # split('_', 1) used to crash on them.
                                        model_filename, sep, remain_filename = \
                                            filepath.name.partition('_')
                                        if not sep or model_filename != name:
                                            continue

                                        if is_rename:
                                            filepath.rename(
                                                filepath.parent /
                                                (new_model_name + '_' +
                                                 remain_filename))
                                        elif filepath.is_dir():
                                            # NOTE(review): get_paths may also yield
                                            # the model's folders (e.g. its history /
                                            # autobackups folders); Path.unlink()
                                            # cannot remove a directory.
                                            import shutil
                                            shutil.rmtree(filepath)
                                        else:
                                            filepath.unlink()
                                # Re-list the models after the command.
                                continue

                        self.model_name = inp
                    else:
                        self.model_name = saved_models_names[model_idx]

                else:
                    self.model_name = io.input_str(
                        f"No saved models found. Enter a name of a new model",
                        "noname")

                break

        self.model_name = self.model_name + '_' + self.model_class_name

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None
        self.choosed_gpu_indexes = None

        model_data = {}
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))
        if self.model_data_path.exists():
            io.log_info(f"Loading {self.model_name} model...")
            # SECURITY NOTE: pickle can execute arbitrary code while loading;
            # only open model files from a trusted source.
            model_data = pickle.loads(self.model_data_path.read_bytes())
            self.iter = model_data.get('iter', 0)
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)
                self.choosed_gpu_indexes = model_data.get(
                    'choosed_gpu_indexes', None)

        if self.is_first_run():
            io.log_info("\nModel first run.")

        if cpu_only:
            self.device_config = nn.DeviceConfig.CPU()
        else:
            self.device_config = nn.DeviceConfig.GPUIndexes(
                force_gpu_idxs
                or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True))

        nn.initialize(self.device_config)

        ####
        self.default_options_path = saved_models_path / f'{self.model_class_name}_default_options.dat'
        self.default_options = {}
        if self.default_options_path.exists():
            try:
                self.default_options = pickle.loads(
                    self.default_options_path.read_bytes())
            except Exception:
                # Best effort: unreadable default options are ignored.
                pass

        self.choose_preview_history = False
        self.batch_size = self.load_or_def_option('batch_size', 1)
        #####

        io.input_skip_pending()

        self.on_initialize_options()
        if self.is_first_run():
            # save as default options only for first run model initialize
            self.default_options_path.write_bytes(pickle.dumps(self.options))

        self.autobackup = self.options.get('autobackup', False)
        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        self.target_iter = self.options.get('target_iter', 0)
        self.random_flip = self.options.get('random_flip', True)

        self.on_initialize()
        self.options['batch_size'] = self.batch_size

        if self.is_training:
            self.preview_history_path = self.saved_models_path / (
                f'{self.get_model_name()}_history')
            self.autobackups_path = self.saved_models_path / (
                f'{self.get_model_name()}_autobackups')

            if self.autobackup:
                self.autobackup_current_hour = time.localtime().tm_hour

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    # A model at iteration 0 starts with an empty history.
                    if self.iter == 0:
                        for filename in pathex.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for generator in self.generator_list:
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            if self.sample_for_preview is None or self.choose_preview_history:
                if self.choose_preview_history and io.is_support_windows():
                    io.log_info(
                        "Choose image for the preview history. [p] - next. [enter] - confirm."
                    )
                    wnd_name = "[p] - next. [enter] - confirm."
                    io.named_window(wnd_name)
                    io.capture_keys(wnd_name)
                    try:
                        choosed = False
                        while not choosed:
                            self.sample_for_preview = self.generate_next_samples()
                            preview = self.get_static_preview()
                            io.show_image(wnd_name,
                                          (preview * 255).astype(np.uint8))

                            while True:
                                key_events = io.get_key_events(wnd_name)
                                # Only the most recent event matters; its first
                                # field is the key code.
                                key = key_events[-1][0] if len(key_events) > 0 else 0
                                if key == ord('\n') or key == ord('\r'):
                                    choosed = True
                                    break
                                elif key == ord('p'):
                                    break

                                try:
                                    io.process_messages(0.1)
                                except KeyboardInterrupt:
                                    # Accept the shown sample; without the break
                                    # the inner loop kept polling after Ctrl+C.
                                    choosed = True
                                    break
                    finally:
                        io.destroy_window(wnd_name)
                else:
                    self.sample_for_preview = self.generate_next_samples()

            try:
                self.get_static_preview()
            except Exception:
                # Best effort: the current sample (presumably e.g. a stale one
                # loaded from data.dat) cannot be rendered, so draw a fresh one.
                self.sample_for_preview = self.generate_next_samples()

            self.last_sample = self.sample_for_preview

        io.log_info(self.get_summary_text())