Beispiel #1
0
    def __init__(self,
                 model_path,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 pretraining_data_path=None,
                 debug=False,
                 device_args=None,
                 ask_enable_autobackup=True,
                 ask_write_preview_history=True,
                 ask_target_iter=True,
                 ask_batch_size=True,
                 ask_sort_by_yaw=True,
                 ask_random_flip=True,
                 ask_src_scale_mod=True):
        """Initialize the model: choose a device, load persisted model state
        from 'data.dat', interactively ask for training options, build the
        network (via onInitialize), and print a formatted model summary.

        Args:
            model_path: directory holding the model files (e.g. 'data.dat').
            training_data_src_path / training_data_dst_path: sample folders;
                training mode is enabled only when BOTH are given.
            pretraining_data_path: optional extra sample folder.
            debug: when True, batch size is forced to 1 below.
            device_args: dict with optional 'force_gpu_idx' / 'cpu_only' keys;
                it is mutated in place with defaults.
                NOTE(review): defaults to None but is subscripted immediately
                below, so callers apparently always pass a dict — confirm.
            ask_enable_autobackup .. ask_src_scale_mod: flags controlling
                which options may be prompted for interactively.
        """
        # Fill in device defaults; note the caller's dict is mutated in place.
        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
        device_args['cpu_only'] = device_args.get('cpu_only', False)

        # No explicit GPU chosen and not CPU-only: if several GPUs exist,
        # let the user pick one (skip keeps -1 = "best GPU").
        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info("You have multi GPUs in a system: ")
                for idx, name in idxs_names_list:
                    io.log_info("[%d] : %s" % (idx, name))

                device_args['force_gpu_idx'] = io.input_int(
                    "Which GPU idx to choose? ( skip: best GPU ) : ", -1,
                    [x[0] for x in idxs_names_list])
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=True,
                                                **self.device_args)

        io.log_info("Loading model...")

        self.model_path = model_path
        # Serialized model state (iter counter, options, loss history, ...).
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
        self.dst_yaw_images_paths = None
        self.src_data_generator = None
        self.dst_data_generator = None
        self.debug = debug
        # Training mode requires both src and dst sample folders.
        self.is_training_mode = (training_data_src_path is not None
                                 and training_data_dst_path is not None)

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        # Restore persisted state if a previous run saved it.
        # NOTE: pickle.loads on a local model file — trusted input assumed.
        model_data = {}
        if self.model_data_path.exists():
            model_data = pickle.loads(self.model_data_path.read_bytes())
            # 'epoch' is presumably the legacy name of 'iter' — take the max
            # of both to migrate old model files. TODO confirm.
            self.iter = max(model_data.get('iter', 0),
                            model_data.get('epoch', 0))
            # NOTE(review): self.options is still {} at this point, so this
            # check can never be true — the intent was probably to pop
            # 'epoch' from the options loaded below. Confirm and fix.
            if 'epoch' in self.options:
                self.options.pop('epoch')
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)

        # On a resumed training run, give the user a short window (longer on
        # Colab) to opt into re-entering the model settings.
        ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time(
            "Press enter in 2 seconds to override model settings.",
            5 if io.is_colab() else 2)

        yn_str = {True: 'y', False: 'n'}

        if self.iter == 0:
            io.log_info(
                "\nModel first run. Enter model options as default for each run."
            )

        # --- Interactive option prompts; each asked only on first run or
        # --- when the user chose to override, otherwise kept/defaulted.
        if ask_enable_autobackup and (self.iter == 0 or ask_override):
            default_autobackup = False if self.iter == 0 else self.options.get(
                'autobackup', False)
            self.options['autobackup'] = io.input_bool(
                "Enable autobackup? (y/n ?:help skip:%s) : " %
                (yn_str[default_autobackup]),
                default_autobackup,
                help_message=
                "Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01"
            )
        else:
            self.options['autobackup'] = self.options.get('autobackup', False)

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get(
                'write_preview_history', False)
            self.options['write_preview_history'] = io.input_bool(
                "Write preview history? (y/n ?:help skip:%s) : " %
                (yn_str[default_write_preview_history]),
                default_write_preview_history,
                help_message=
                "Preview history will be writed to <ModelName>_history folder."
            )
        else:
            self.options['write_preview_history'] = self.options.get(
                'write_preview_history', False)

        # Preview-image choice is platform dependent: manual picking needs a
        # window (Windows); Colab can only re-randomize the image.
        if (self.iter == 0 or ask_override) and self.options[
                'write_preview_history'] and io.is_support_windows():
            choose_preview_history = io.input_bool(
                "Choose image for the preview history? (y/n skip:%s) : " %
                (yn_str[False]), False)
        elif (self.iter == 0 or ask_override
              ) and self.options['write_preview_history'] and io.is_colab():
            choose_preview_history = io.input_bool(
                "Randomly choose new image for preview history? (y/n ?:help skip:%s) : "
                % (yn_str[False]),
                False,
                help_message=
                "Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person"
            )
        else:
            choose_preview_history = False

        if ask_target_iter:
            if (self.iter == 0 or ask_override):
                self.options['target_iter'] = max(
                    0,
                    io.input_int(
                        "Target iteration (skip:unlimited/default) : ", 0))
            else:
                # Migrate the legacy 'target_epoch' key into 'target_iter'.
                self.options['target_iter'] = max(
                    model_data.get('target_iter', 0),
                    self.options.get('target_epoch', 0))
                if 'target_epoch' in self.options:
                    self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get(
                'batch_size', 0)
            self.options['batch_size'] = max(
                0,
                io.input_int(
                    "Batch_size (?:help skip:%d) : " % (default_batch_size),
                    default_batch_size,
                    help_message=
                    "Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."
                ))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        if ask_sort_by_yaw:
            if (self.iter == 0 or ask_override):
                default_sort_by_yaw = self.options.get('sort_by_yaw', False)
                self.options['sort_by_yaw'] = io.input_bool(
                    "Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : "
                    % (yn_str[default_sort_by_yaw]),
                    default_sort_by_yaw,
                    help_message=
                    "NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw."
                )
            else:
                self.options['sort_by_yaw'] = self.options.get(
                    'sort_by_yaw', False)

        # random_flip / src_scale_mod are only asked on the very first run.
        if ask_random_flip:
            if (self.iter == 0):
                self.options['random_flip'] = io.input_bool(
                    "Flip faces randomly? (y/n ?:help skip:y) : ",
                    True,
                    help_message=
                    "Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset."
                )
            else:
                self.options['random_flip'] = self.options.get(
                    'random_flip', True)

        if ask_src_scale_mod:
            if (self.iter == 0):
                self.options['src_scale_mod'] = np.clip(
                    io.input_int(
                        "Src face scale modifier % ( -30...30, ?:help skip:0) : ",
                        0,
                        help_message=
                        "If src face shape is wider than dst, try to decrease this value to get a better result."
                    ), -30, 30)
            else:
                self.options['src_scale_mod'] = self.options.get(
                    'src_scale_mod', 0)

        # --- Copy options into attributes; default-valued entries are popped
        # --- from self.options so only non-default settings get persisted.
        self.autobackup = self.options.get('autobackup', False)
        if not self.autobackup and 'autobackup' in self.options:
            self.options.pop('autobackup')

        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
            self.options.pop('write_preview_history')

        self.target_iter = self.options.get('target_iter', 0)
        if self.target_iter == 0 and 'target_iter' in self.options:
            self.options.pop('target_iter')

        self.batch_size = self.options.get('batch_size', 0)
        self.sort_by_yaw = self.options.get('sort_by_yaw', False)
        self.random_flip = self.options.get('random_flip', True)

        self.src_scale_mod = self.options.get('src_scale_mod', 0)
        if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
            self.options.pop('src_scale_mod')

        # Subclass hook for model-specific options.
        self.onInitializeOptions(self.iter == 0, ask_override)

        # Import the backend for the chosen device and build the network.
        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        self.onInitialize()

        # onInitializeOptions/onInitialize may have changed batch_size;
        # write it back so it gets persisted.
        self.options['batch_size'] = self.batch_size

        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            # History/backup folders are prefixed with the GPU index when one
            # was forced, so multi-GPU runs don't collide.
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / (
                    '%s_history' % (self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%s_autobackups' % (self.get_model_name()))
            else:
                self.preview_history_path = self.model_path / (
                    '%d_%s_history' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%d_%s_autobackups' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))

            if self.autobackup:
                self.autobackup_current_hour = time.localtime().tm_hour

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    # Fresh model: clear stale preview images from a previous
                    # model of the same name.
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for i, generator in enumerate(self.generator_list):
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            # Pick the sample used for the preview image: interactively on
            # Windows ([p] cycles, [enter] confirms), otherwise just the next
            # generated sample.
            if self.sample_for_preview is None or choose_preview_history:
                if choose_preview_history and io.is_support_windows():
                    wnd_name = "[p] - next. [enter] - confirm."
                    io.named_window(wnd_name)
                    io.capture_keys(wnd_name)
                    choosed = False
                    while not choosed:
                        self.sample_for_preview = self.generate_next_sample()
                        preview = self.get_static_preview()
                        io.show_image(wnd_name,
                                      (preview * 255).astype(np.uint8))

                        while True:
                            key_events = io.get_key_events(wnd_name)
                            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                                -1] if len(key_events) > 0 else (0, 0, False,
                                                                 False, False)
                            if key == ord('\n') or key == ord('\r'):
                                choosed = True
                                break
                            elif key == ord('p'):
                                break

                            try:
                                io.process_messages(0.1)
                            except KeyboardInterrupt:
                                choosed = True

                    io.destroy_window(wnd_name)
                else:
                    self.sample_for_preview = self.generate_next_sample()
                self.last_sample = self.sample_for_preview

        ###Generate text summary of model hyperparameters
        #Find the longest key name and value string. Used as column widths.
        width_name = max(
            [len(k) for k in self.options.keys()] + [17]
        ) + 1  # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
        width_value = max([len(str(x)) for x in self.options.values()] +
                          [len(str(self.iter)),
                           len(self.get_model_name())]
                          ) + 1  # Single space buffer to right edge
        if not self.device_config.cpu_only:  #Check length of GPU names
            width_value = max([
                len(nnlib.device.getDeviceName(idx)) + 1
                for idx in self.device_config.gpu_idxs
            ] + [width_value])
        width_total = width_name + width_value + 2  #Plus 2 for ": "

        model_summary_text = []
        model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='
                               ]  # Model/status summary
        model_summary_text += [f'=={" "*width_total}==']
        model_summary_text += [
            f'=={"Model name": >{width_name}}: {self.get_model_name(): <{width_value}}=='
        ]  # Name
        model_summary_text += [f'=={" "*width_total}==']
        model_summary_text += [
            f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='
        ]  # Iter
        model_summary_text += [f'=={" "*width_total}==']

        model_summary_text += [f'=={" Model Options ":-^{width_total}}=='
                               ]  # Model options
        model_summary_text += [f'=={" "*width_total}==']
        for key in self.options.keys():
            model_summary_text += [
                f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='
            ]  # self.options key/value pairs
        model_summary_text += [f'=={" "*width_total}==']

        model_summary_text += [f'=={" Running On ":-^{width_total}}=='
                               ]  # Training hardware info
        model_summary_text += [f'=={" "*width_total}==']
        if self.device_config.multi_gpu:
            model_summary_text += [
                f'=={"Using multi_gpu": >{width_name}}: {"True": <{width_value}}=='
            ]  # multi_gpu
            model_summary_text += [f'=={" "*width_total}==']
        if self.device_config.cpu_only:
            model_summary_text += [
                f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='
            ]  # cpu_only
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += [
                    f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='
                ]  # GPU hardware device index
                model_summary_text += [
                    f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='
                ]  # GPU name
                vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB'  # GPU VRAM - Formated as #.## (or ##.##)
                model_summary_text += [
                    f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}=='
                ]
        model_summary_text += [f'=={" "*width_total}==']
        model_summary_text += [f'=={"="*width_total}==']

        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[
                0] <= 2:  # Low VRAM warning
            model_summary_text += ["/!\\"]
            model_summary_text += ["/!\\ WARNING:"]
            model_summary_text += [
                "/!\\ You are using a GPU with 2GB or less VRAM. This may significantly reduce the quality of your result!"
            ]
            model_summary_text += [
                "/!\\ If training does not start, close all programs and try again."
            ]
            model_summary_text += [
                "/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."
            ]
            model_summary_text += ["/!\\"]

        model_summary_text = "\n".join(model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)
Beispiel #2
0
def main(args, device_args):
    """Run the trainer in a background thread and drive the preview UI.

    Spawns trainerThread and communicates with it over two queues
    (s2c: commands to the trainer, c2s: status/previews back to us).
    In 'no_preview' mode only waits for the trainer's 'close' message;
    otherwise it opens a preview window and handles keyboard commands:
      [enter] exit, [s] save, [p] request a new preview,
      [l] cycle the loss-history range, [space] next preview image.

    Args:
        args: dict of trainer arguments; 'no_preview' (bool) is read here,
            the rest is passed through to trainerThread.
        device_args: device configuration dict, passed to trainerThread.
    """
    io.log_info("Running trainer.\r\n")

    no_preview = args.get('no_preview', False)

    s2c = queue.Queue()  # commands: UI -> trainer
    c2s = queue.Queue()  # messages: trainer -> UI

    thread = threading.Thread(target=trainerThread,
                              args=(s2c, c2s, args, device_args))
    thread.start()

    if no_preview:
        # Headless mode: just wait for the trainer to report it has closed.
        # (Renamed the message variable — the original shadowed builtin
        # `input`.)
        while True:
            if not c2s.empty():
                msg = c2s.get()
                if msg.get('op', '') == 'close':
                    break
            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                # Ask the trainer to shut down; loop exits on its 'close'.
                s2c.put({'op': 'close'})
    else:
        wnd_name = "Training preview"
        io.named_window(wnd_name)
        io.capture_keys(wnd_name)

        previews = None
        loss_history = None
        selected_preview = 0
        update_preview = False
        is_waiting_preview = False
        show_last_history_iters_count = 0  # 0 means show the full history
        cur_iter = 0  # renamed — the original shadowed builtin `iter`
        while True:
            if not c2s.empty():
                msg = c2s.get()
                op = msg['op']
                if op == 'show':
                    is_waiting_preview = False
                    loss_history = msg.get('loss_history', None)
                    previews = msg.get('previews', None)
                    cur_iter = msg.get('iter', 0)
                    if previews is not None:
                        # Find the largest preview size, capped at 800px
                        # height (width scaled proportionally).
                        max_w = 0
                        max_h = 0
                        for (preview_name, preview_rgb) in previews:
                            (h, w, c) = preview_rgb.shape
                            max_h = max(max_h, h)
                            max_w = max(max_w, w)

                        max_size = 800
                        if max_h > max_size:
                            max_w = int(max_w / (max_h / max_size))
                            max_h = max_size

                        # Make all previews the same size so any of them can
                        # be shown in the one window.
                        for preview in previews[:]:
                            (preview_name, preview_rgb) = preview
                            (h, w, c) = preview_rgb.shape
                            if h != max_h or w != max_w:
                                previews.remove(preview)
                                previews.append(
                                    (preview_name,
                                     cv2.resize(preview_rgb, (max_w, max_h))))
                        # Keep the selection valid if the preview count changed.
                        selected_preview = selected_preview % len(previews)
                        update_preview = True
                elif op == 'close':
                    break

            if update_preview:
                update_preview = False

                selected_preview_name = previews[selected_preview][0]
                selected_preview_rgb = previews[selected_preview][1]
                (h, w, c) = selected_preview_rgb.shape

                # Header: key bindings and the preview selector state.
                head_lines = [
                    '[s]:save [enter]:exit',
                    '[p]:update [space]:next preview [l]:change history range',
                    'Preview: "%s" [%d/%d]' %
                    (selected_preview_name, selected_preview + 1,
                     len(previews))
                ]
                head_line_height = 15
                head_height = len(head_lines) * head_line_height
                head = np.ones((head_height, w, c)) * 0.1

                for i in range(0, len(head_lines)):
                    t = i * head_line_height
                    b = (i + 1) * head_line_height
                    head[t:b, 0:w] += imagelib.get_text_image(
                        (head_line_height, w, c),
                        head_lines[i],
                        color=[0.8] * c)

                final = head

                # Optional loss-history graph between header and preview.
                if loss_history is not None:
                    if show_last_history_iters_count == 0:
                        loss_history_to_show = loss_history
                    else:
                        loss_history_to_show = loss_history[
                            -show_last_history_iters_count:]

                    lh_img = models.ModelBase.get_loss_history_preview(
                        loss_history_to_show, cur_iter, w, c)
                    final = np.concatenate([final, lh_img], axis=0)

                final = np.concatenate([final, selected_preview_rgb], axis=0)
                final = np.clip(final, 0, 1)

                io.show_image(wnd_name, (final * 255).astype(np.uint8))

            key_events = io.get_key_events(wnd_name)
            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                -1] if len(key_events) > 0 else (0, 0, False, False, False)

            if key == ord('\n') or key == ord('\r'):
                s2c.put({'op': 'close'})
            elif key == ord('s'):
                s2c.put({'op': 'save'})
            elif key == ord('p'):
                # Allow only one outstanding preview request at a time.
                if not is_waiting_preview:
                    is_waiting_preview = True
                    s2c.put({'op': 'preview'})
            elif key == ord('l'):
                # Cycle history range: all -> 5k -> 10k -> 50k -> 100k -> all.
                if show_last_history_iters_count == 0:
                    show_last_history_iters_count = 5000
                elif show_last_history_iters_count == 5000:
                    show_last_history_iters_count = 10000
                elif show_last_history_iters_count == 10000:
                    show_last_history_iters_count = 50000
                elif show_last_history_iters_count == 50000:
                    show_last_history_iters_count = 100000
                elif show_last_history_iters_count == 100000:
                    show_last_history_iters_count = 0
                update_preview = True
            elif key == ord(' '):
                selected_preview = (selected_preview + 1) % len(previews)
                update_preview = True

            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                s2c.put({'op': 'close'})

        io.destroy_all_windows()
Beispiel #3
0
    def get_data(self, host_dict):
        """Return the next extraction work item, or None when exhausted.

        Non-manual mode: pops and returns the next queued (filename, faces)
        item directly.

        Manual mode: runs an interactive marking loop over self.input_data.
        For each frame it (re)builds cached scaled images and a help overlay,
        then processes mouse/keyboard events. Whenever the selection rect
        changes it returns [filename, [rect]] so the host can redraw; key
        handling (grounded in the on-screen help text below):
          [Enter]       confirm the current face and finish the frame
          [Space]       confirm frame as unmarked and continue
          [mouse wheel] resize the rect (when not locked)
          [left click]  toggle rect lock
          [.] / [,]     next / previous frame
          [q]           skip all remaining frames
          [h]           toggle the help overlay

        Args:
            host_dict: unused here — presumably part of the subprocess
                host protocol. TODO confirm against the caller.
        """
        if not self.manual:
            if len(self.input_data) > 0:
                return self.input_data.pop(0)
        else:
            skip_remaining = False
            allow_remark_faces = False
            while len(self.input_data) > 0:
                data = self.input_data[0]
                filename, faces = data
                is_frame_done = False
                go_to_prev_frame = False

                # Can we mark an image that already has a marked face?
                if allow_remark_faces:
                    allow_remark_faces = False
                    # If there was already a face then lock the rectangle to it until the mouse is clicked
                    if len(faces) > 0:
                        self.rect, self.landmarks = faces.pop()

                        self.rect_locked = True
                        self.redraw_needed = True
                        faces.clear()
                        # Re-derive rect size and center from the stored rect
                        # (rect layout: left, top, right, bottom).
                        self.rect_size = (self.rect[2] - self.rect[0]) / 2
                        self.x = (self.rect[0] + self.rect[2]) / 2
                        self.y = (self.rect[1] + self.rect[3]) / 2

                if len(faces) == 0:
                    # One-slot cache for the decoded original image.
                    if self.cache_original_image[0] == filename:
                        self.original_image = self.cache_original_image[1]
                    else:
                        self.original_image = cv2_imread(filename)
                        self.cache_original_image = (filename,
                                                     self.original_image)

                    (h, w, c) = self.original_image.shape
                    # Scale so the frame fits the manual window; the 16/9
                    # factor presumably assumes a widescreen display — confirm.
                    self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / (
                        h * (16.0 / 9.0))

                    # One-slot cache for the view-scaled image, keyed by
                    # (shape, scale, filename).
                    if self.cache_image[0] == (h, w, c) + (self.view_scale,
                                                           filename):
                        self.image = self.cache_image[1]
                    else:
                        self.image = cv2.resize(self.original_image, (int(
                            w * self.view_scale), int(h * self.view_scale)),
                                                interpolation=cv2.INTER_LINEAR)
                        self.cache_image = ((h, w, c) +
                                            (self.view_scale, filename),
                                            self.image)

                    (h, w, c) = self.image.shape

                    # Help overlay rendered into the top strip (max 100 px),
                    # cached by its geometry.
                    sh = (0, 0, w, min(100, h))
                    if self.cache_text_lines_img[0] == sh:
                        self.text_lines_img = self.cache_text_lines_img[1]
                    else:
                        self.text_lines_img = (image_utils.get_draw_text_lines(
                            self.image, sh, [
                                'Match landmarks with face exactly. Click to confirm/unconfirm selection',
                                '[Enter] - confirm face landmarks and continue',
                                '[Space] - confirm as unmarked frame and continue',
                                '[Mouse wheel] - change rect',
                                '[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
                                '[h] - hide this help'
                            ], (1, 1, 1)) * 255).astype(np.uint8)

                        self.cache_text_lines_img = (sh, self.text_lines_img)

                    # Event loop for the current frame.
                    while True:
                        new_x = self.x
                        new_y = self.y
                        new_rect_size = self.rect_size

                        mouse_events = io.get_mouse_events(self.wnd_name)
                        for ev in mouse_events:
                            (x, y, ev, flags) = ev
                            if ev == io.EVENT_MOUSEWHEEL and not self.rect_locked:
                                # Wheel resizes the rect; step grows with the
                                # rect size (clamped to 1..10) above 40 px.
                                mod = 1 if flags > 0 else -1
                                diff = 1 if new_rect_size <= 40 else np.clip(
                                    new_rect_size / 10, 1, 10)
                                new_rect_size = max(5,
                                                    new_rect_size + diff * mod)
                            elif ev == io.EVENT_LBUTTONDOWN:
                                self.rect_locked = not self.rect_locked
                                self.redraw_needed = True
                            elif not self.rect_locked:
                                # Mouse move: track cursor in original-image
                                # coordinates.
                                new_x = np.clip(x, 0, w - 1) / self.view_scale
                                new_y = np.clip(y, 0, h - 1) / self.view_scale

                        key_events = io.get_key_events(self.wnd_name)
                        key, = key_events[-1] if len(key_events) > 0 else (0, )

                        if key == ord('\r') or key == ord('\n'):
                            faces.append([(self.rect), self.landmarks])
                            is_frame_done = True
                            break
                        elif key == ord(' '):
                            is_frame_done = True
                            break
                        elif key == ord('.'):
                            allow_remark_faces = True
                            # Only save the face if the rect is still locked
                            if self.rect_locked:
                                faces.append([(self.rect), self.landmarks])
                            is_frame_done = True
                            break
                        elif key == ord(',') and len(self.result) > 0:
                            # Only save the face if the rect is still locked
                            if self.rect_locked:
                                faces.append([(self.rect), self.landmarks])
                            go_to_prev_frame = True
                            break
                        elif key == ord('q'):
                            skip_remaining = True
                            break
                        elif key == ord('h'):
                            self.hide_help = not self.hide_help
                            break

                        # Rect moved/resized (or redraw requested): commit the
                        # new geometry and hand it to the host for redrawing.
                        if self.x != new_x or \
                           self.y != new_y or \
                           self.rect_size != new_rect_size or \
                           self.redraw_needed:
                            self.x = new_x
                            self.y = new_y
                            self.rect_size = new_rect_size

                            self.rect = (int(self.x - self.rect_size),
                                         int(self.y - self.rect_size),
                                         int(self.x + self.rect_size),
                                         int(self.y + self.rect_size))

                            return [filename, [self.rect]]

                        io.process_messages(0.0001)
                else:
                    # Frame already has a face and remarking is not allowed.
                    is_frame_done = True

                if is_frame_done:
                    self.result.append(data)
                    self.input_data.pop(0)
                    io.progress_bar_inc(1)
                    self.redraw_needed = True
                    self.rect_locked = False
                elif go_to_prev_frame:
                    # Move the last finished frame back to the input queue.
                    self.input_data.insert(0, self.result.pop())
                    io.progress_bar_inc(-1)
                    allow_remark_faces = True
                    self.redraw_needed = True
                    self.rect_locked = False
                elif skip_remaining:
                    if self.rect_locked:
                        faces.append([(self.rect), self.landmarks])
                    # Flush everything left straight into the results.
                    while len(self.input_data) > 0:
                        self.result.append(self.input_data.pop(0))
                        io.progress_bar_inc(1)

        return None
Beispiel #4
0
def mask_editor_main(input_dir,
                     confirmed_dir=None,
                     skipped_dir=None,
                     no_default_mask=False):
    """Interactive mask editor over a directory of dfl face images.

    For each .png/.jpg dfl image in input_dir the user can draw
    include/exclude polygons on the face mask, then save the result in
    place, save-and-move it to confirmed_dir, or skip-and-move it to
    skipped_dir.  Editing is keyboard/mouse driven; the key bindings are
    shown on screen (see get_status_lines_func below).

    input_dir       : directory containing the images to edit.
    confirmed_dir   : destination directory for saved-and-moved images.
    skipped_dir     : destination directory for skipped-and-moved images.
    no_default_mask : if True, start with an all-zero mask instead of the
                      landmarks-hull default mask.

    Raises ValueError if input_dir does not exist or if confirmed_dir /
    skipped_dir are not provided.
    """
    input_path = Path(input_dir)

    # Fail fast with a clear error instead of the confusing TypeError that
    # Path(None) raises when the caller leaves the defaults unset.
    if confirmed_dir is None:
        raise ValueError('confirmed_dir must be specified.')
    if skipped_dir is None:
        raise ValueError('skipped_dir must be specified.')

    confirmed_path = Path(confirmed_dir)
    skipped_path = Path(skipped_dir)

    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')

    if not confirmed_path.exists():
        confirmed_path.mkdir(parents=True)

    if not skipped_path.exists():
        skipped_path.mkdir(parents=True)

    if not no_default_mask:
        # Stored as a fraction (0.0 .. 4.0) of the default eyebrow expansion.
        eyebrows_expand_mod = np.clip(
            io.input_int(
                "Default eyebrows expand modifier? (0..400, skip:100) : ",
                100), 0, 400) / 100.0
    else:
        eyebrows_expand_mod = None

    wnd_name = "MaskEditor tool"
    io.named_window(wnd_name)
    io.capture_mouse(wnd_name)
    io.capture_keys(wnd_name)

    # filename -> float image in [0..1]; cache for the preview strips.
    cached_images = {}

    image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)]
    done_paths = []
    # filename -> 0: reopened via prev, 1: skipped, 2: saved.
    done_images_types = {}
    image_paths_total = len(image_paths)
    saved_ie_polys = IEPolys()
    zoom_factor = 1.0
    preview_images_count = 9
    target_wh = 256

    # Pending-action counters; holding [shift] queues 10 repeats at once.
    do_prev_count = 0
    do_save_move_count = 0
    do_save_count = 0
    do_skip_move_count = 0
    do_skip_count = 0

    def jobs_count():
        # Total queued navigation/save/skip actions still to be processed.
        return do_prev_count + do_save_move_count + do_save_count + do_skip_move_count + do_skip_count

    is_exit = False
    while not is_exit:

        if len(image_paths) > 0:
            filepath = image_paths.pop(0)
        else:
            filepath = None

        next_image_paths = image_paths[0:preview_images_count]
        next_image_paths_names = [path.name for path in next_image_paths]
        prev_image_paths = done_paths[-preview_images_count:]
        prev_image_paths_names = [path.name for path in prev_image_paths]

        # Evict cached images no longer visible in either preview strip.
        for key in list(cached_images.keys()):
            if key not in prev_image_paths_names and \
               key not in next_image_paths_names:
                cached_images.pop(key)

        for paths in [prev_image_paths, next_image_paths]:
            for path in paths:
                if path.name not in cached_images:
                    cached_images[path.name] = cv2_imread(str(path)) / 255.0

        if filepath is not None:
            if filepath.suffix == '.png':
                dflimg = DFLPNG.load(str(filepath))
            elif filepath.suffix == '.jpg':
                dflimg = DFLJPG.load(str(filepath))
            else:
                dflimg = None

            if dflimg is None:
                io.log_err("%s is not a dfl image file" % (filepath.name))
                continue
            else:
                lmrks = dflimg.get_landmarks()
                ie_polys = dflimg.get_ie_polys()
                fanseg_mask = dflimg.get_fanseg_mask()

                if filepath.name in cached_images:
                    img = cached_images[filepath.name]
                else:
                    img = cached_images[filepath.name] = cv2_imread(
                        str(filepath)) / 255.0

                # Prefer a previously stored fanseg mask; otherwise fall back
                # to an empty mask or the landmarks hull mask.
                if fanseg_mask is not None:
                    mask = fanseg_mask
                else:
                    if no_default_mask:
                        mask = np.zeros((target_wh, target_wh, 3))
                    else:
                        mask = LandmarksProcessor.get_image_hull_mask(
                            img.shape,
                            lmrks,
                            eyebrows_expand_mod=eyebrows_expand_mod)
        else:
            # No images left: show a blank "end" screen.
            img = np.zeros((target_wh, target_wh, 3))
            mask = np.ones((target_wh, target_wh, 3))
            ie_polys = None

        def get_status_lines_func():
            # Help / progress text rendered by the editor each frame.
            return [
                'Progress: %d / %d . Current file: %s' %
                (len(done_paths), image_paths_total,
                 str(filepath.name) if filepath is not None else "end"),
                '[Left mouse button] - mark include mask.',
                '[Right mouse button] - mark exclude mask.',
                '[Middle mouse button] - finish current poly.',
                '[Mouse wheel] - undo/redo poly or point. [+ctrl] - undo to begin/redo to end',
                '[r] - applies edits made to last saved image.',
                '[q] - prev image. [w] - skip and move to %s. [e] - save and move to %s. '
                % (skipped_path.name, confirmed_path.name),
                '[z] - prev image. [x] - skip. [c] - save. ',
                'hold [shift] - speed up the frame counter by 10.',
                '[-/+] - window zoom [esc] - quit',
            ]

        try:
            ed = MaskEditor(img,
                            [(done_images_types[name], cached_images[name])
                             for name in prev_image_paths_names],
                            [(0, cached_images[name])
                             for name in next_image_paths_names], mask,
                            ie_polys, get_status_lines_func)
        except Exception as e:
            print(e)
            continue

        next = False
        while not next:
            io.process_messages(0.005)

            # Only accept new user input when no queued actions are pending.
            if jobs_count() == 0:
                for (x, y, ev, flags) in io.get_mouse_events(wnd_name):
                    x, y = int(x / zoom_factor), int(y / zoom_factor)
                    ed.set_mouse_pos(x, y)
                    if filepath is not None:
                        if ev == io.EVENT_LBUTTONDOWN:
                            ed.mask_point(1)
                        elif ev == io.EVENT_RBUTTONDOWN:
                            ed.mask_point(0)
                        elif ev == io.EVENT_MBUTTONDOWN:
                            ed.mask_finish()
                        elif ev == io.EVENT_MOUSEWHEEL:
                            # High bit of flags encodes wheel direction;
                            # 0x8 is the ctrl modifier (undo/redo to end).
                            if flags & 0x80000000 != 0:
                                if flags & 0x8 != 0:
                                    ed.undo_to_begin_point()
                                else:
                                    ed.undo_point()
                            else:
                                if flags & 0x8 != 0:
                                    ed.redo_to_end_point()
                                else:
                                    ed.redo_point()

                for key, chr_key, ctrl_pressed, alt_pressed, shift_pressed in io.get_key_events(
                        wnd_name):
                    if chr_key == 'q' or chr_key == 'z':
                        do_prev_count = 1 if not shift_pressed else 10
                    elif chr_key == '-':
                        zoom_factor = np.clip(zoom_factor - 0.1, 0.1, 4.0)
                        ed.set_screen_changed()
                    elif chr_key == '+':
                        zoom_factor = np.clip(zoom_factor + 0.1, 0.1, 4.0)
                        ed.set_screen_changed()
                    elif key == 27:  #esc
                        is_exit = True
                        next = True
                        break
                    elif filepath is not None:
                        if chr_key == 'e':
                            saved_ie_polys = ed.ie_polys
                            do_save_move_count = 1 if not shift_pressed else 10
                        elif chr_key == 'c':
                            saved_ie_polys = ed.ie_polys
                            do_save_count = 1 if not shift_pressed else 10
                        elif chr_key == 'w':
                            do_skip_move_count = 1 if not shift_pressed else 10
                        elif chr_key == 'x':
                            do_skip_count = 1 if not shift_pressed else 10
                        elif chr_key == 'r' and saved_ie_polys is not None:
                            # Re-apply polys from the last saved image.
                            ed.set_ie_polys(saved_ie_polys)

            if do_prev_count > 0:
                do_prev_count -= 1
                if len(done_paths) > 0:
                    if filepath is not None:
                        image_paths.insert(0, filepath)

                    filepath = done_paths.pop(-1)
                    done_images_types[filepath.name] = 0

                    # If the previous image was moved out, move it back to
                    # the input directory before reopening it.
                    if filepath.parent != input_path:
                        new_filename_path = input_path / filepath.name
                        filepath.rename(new_filename_path)
                        image_paths.insert(0, new_filename_path)
                    else:
                        image_paths.insert(0, filepath)

                    next = True
            elif filepath is not None:
                if do_save_move_count > 0:
                    do_save_move_count -= 1

                    # Persist edited polys, then move to confirmed dir.
                    ed.mask_finish()
                    dflimg.embed_and_set(
                        str(filepath),
                        ie_polys=ed.get_ie_polys(),
                        eyebrows_expand_mod=eyebrows_expand_mod)

                    done_paths += [confirmed_path / filepath.name]
                    done_images_types[filepath.name] = 2
                    filepath.rename(done_paths[-1])

                    next = True
                elif do_save_count > 0:
                    do_save_count -= 1

                    # Persist edited polys in place.
                    ed.mask_finish()
                    dflimg.embed_and_set(
                        str(filepath),
                        ie_polys=ed.get_ie_polys(),
                        eyebrows_expand_mod=eyebrows_expand_mod)

                    done_paths += [filepath]
                    done_images_types[filepath.name] = 2

                    next = True
                elif do_skip_move_count > 0:
                    do_skip_move_count -= 1

                    done_paths += [skipped_path / filepath.name]
                    done_images_types[filepath.name] = 1
                    filepath.rename(done_paths[-1])

                    next = True
                elif do_skip_count > 0:
                    do_skip_count -= 1

                    done_paths += [filepath]
                    done_images_types[filepath.name] = 1

                    next = True
            else:
                # On the "end" screen there is nothing to save or skip.
                do_save_move_count = do_save_count = do_skip_move_count = do_skip_count = 0

            if jobs_count() == 0:
                if ed.switch_screen_changed():
                    screen = ed.make_screen()
                    if zoom_factor != 1.0:
                        h, w, c = screen.shape
                        screen = cv2.resize(
                            screen,
                            (int(w * zoom_factor), int(h * zoom_factor)))
                    io.show_image(wnd_name, screen)

        io.process_messages(0.005)

    io.destroy_all_windows()
# Beispiel #5
# 0
    def __init__(self,
                 model_path,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 pretraining_data_path=None,
                 debug=False,
                 device_args=None,
                 ask_enable_autobackup=True,
                 ask_write_preview_history=True,
                 ask_target_iter=True,
                 ask_batch_size=True,
                 ask_sort_by_yaw=True,
                 ask_random_flip=True,
                 ask_src_scale_mod=True):
        """Restore model state from disk and interactively collect options.

        model_path            : directory holding the model files ('data.dat'
                                is read/written there).
        training_data_*_path  : src/dst faceset dirs; training mode is active
                                only when both are given.
        pretraining_data_path : optional pretraining faceset dir.
        debug                 : debug mode forces batch_size to 1.
        device_args           : optional dict with 'force_gpu_idx' and
                                'cpu_only' keys; missing keys are defaulted.
        ask_*                 : toggles for the individual console prompts
                                (used by subclasses to suppress questions
                                that do not apply to them).
        """
        # Guard against the None default: the lines below subscript
        # device_args unconditionally.
        if device_args is None:
            device_args = {}
        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
        device_args['cpu_only'] = device_args.get('cpu_only', False)

        # Let the user pick a GPU when several are available and none was
        # forced via the command line.
        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info("You have multi GPUs in a system: ")
                for idx, name in idxs_names_list:
                    io.log_info("[%d] : %s" % (idx, name))

                device_args['force_gpu_idx'] = io.input_int(
                    "Which GPU idx to choose? ( skip: best GPU ) : ", -1,
                    [x[0] for x in idxs_names_list])
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=True,
                                                **self.device_args)

        io.log_info("Loading model...")

        self.model_path = model_path
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
        self.dst_yaw_images_paths = None
        self.src_data_generator = None
        self.dst_data_generator = None
        self.debug = debug
        self.is_training_mode = (training_data_src_path is not None
                                 and training_data_dst_path is not None)

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        model_data = {}
        if self.model_data_path.exists():
            model_data = pickle.loads(self.model_data_path.read_bytes())
            # 'epoch' is the legacy name of 'iter'.
            self.iter = max(model_data.get('iter', 0),
                            model_data.get('epoch', 0))
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)
                # Drop the legacy key only after the stored options were
                # loaded. (The original check ran against the still-empty
                # self.options and therefore never fired.)
                if 'epoch' in self.options:
                    self.options.pop('epoch')

        ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time(
            "Press enter in 2 seconds to override model settings.",
            5 if io.is_colab() else 2)

        yn_str = {True: 'y', False: 'n'}

        if self.iter == 0:
            io.log_info(
                "\nModel first run. Enter model options as default for each run."
            )

        if ask_enable_autobackup and (self.iter == 0 or ask_override):
            default_autobackup = False if self.iter == 0 else self.options.get(
                'autobackup', False)
            self.options['autobackup'] = io.input_bool(
                "Enable autobackup? (y/n ?:help skip:%s) : " %
                (yn_str[default_autobackup]),
                default_autobackup,
                help_message=
                "Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01"
            )
        else:
            self.options['autobackup'] = self.options.get('autobackup', False)

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get(
                'write_preview_history', False)
            self.options['write_preview_history'] = io.input_bool(
                "Write preview history? (y/n ?:help skip:%s) : " %
                (yn_str[default_write_preview_history]),
                default_write_preview_history,
                help_message=
                "Preview history will be writed to <ModelName>_history folder."
            )
        else:
            self.options['write_preview_history'] = self.options.get(
                'write_preview_history', False)

        # Preview-image selection is only offered on Windows (interactive
        # window) or on colab (random choice).
        if (self.iter == 0 or ask_override) and self.options[
                'write_preview_history'] and io.is_support_windows():
            choose_preview_history = io.input_bool(
                "Choose image for the preview history? (y/n skip:%s) : " %
                (yn_str[False]), False)
        elif (self.iter == 0 or ask_override
              ) and self.options['write_preview_history'] and io.is_colab():
            choose_preview_history = io.input_bool(
                "Randomly choose new image for preview history? (y/n ?:help skip:%s) : "
                % (yn_str[False]),
                False,
                help_message=
                "Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person"
            )
        else:
            choose_preview_history = False

        if ask_target_iter:
            if (self.iter == 0 or ask_override):
                self.options['target_iter'] = max(
                    0,
                    io.input_int(
                        "Target iteration (skip:unlimited/default) : ", 0))
            else:
                # 'target_epoch' is the legacy name of 'target_iter'.
                self.options['target_iter'] = max(
                    model_data.get('target_iter', 0),
                    self.options.get('target_epoch', 0))
                if 'target_epoch' in self.options:
                    self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get(
                'batch_size', 0)
            self.options['batch_size'] = max(
                0,
                io.input_int(
                    "Batch_size (?:help skip:%d) : " % (default_batch_size),
                    default_batch_size,
                    help_message=
                    "Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."
                ))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        if ask_sort_by_yaw:
            if (self.iter == 0 or ask_override):
                default_sort_by_yaw = self.options.get('sort_by_yaw', False)
                self.options['sort_by_yaw'] = io.input_bool(
                    "Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : "
                    % (yn_str[default_sort_by_yaw]),
                    default_sort_by_yaw,
                    help_message=
                    "NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw."
                )
            else:
                self.options['sort_by_yaw'] = self.options.get(
                    'sort_by_yaw', False)

        if ask_random_flip:
            if (self.iter == 0):
                self.options['random_flip'] = io.input_bool(
                    "Flip faces randomly? (y/n ?:help skip:y) : ",
                    True,
                    help_message=
                    "Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset."
                )
            else:
                self.options['random_flip'] = self.options.get(
                    'random_flip', True)

        if ask_src_scale_mod:
            if (self.iter == 0):
                self.options['src_scale_mod'] = np.clip(
                    io.input_int(
                        "Src face scale modifier % ( -30...30, ?:help skip:0) : ",
                        0,
                        help_message=
                        "If src face shape is wider than dst, try to decrease this value to get a better result."
                    ), -30, 30)
            else:
                self.options['src_scale_mod'] = self.options.get(
                    'src_scale_mod', 0)

        # Copy options to attributes; default-valued options are removed
        # from the dict so they are not persisted in data.dat.
        self.autobackup = self.options.get('autobackup', False)
        if not self.autobackup and 'autobackup' in self.options:
            self.options.pop('autobackup')

        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
            self.options.pop('write_preview_history')

        self.target_iter = self.options.get('target_iter', 0)
        if self.target_iter == 0 and 'target_iter' in self.options:
            self.options.pop('target_iter')

        self.batch_size = self.options.get('batch_size', 0)
        self.sort_by_yaw = self.options.get('sort_by_yaw', False)
        self.random_flip = self.options.get('random_flip', True)

        self.src_scale_mod = self.options.get('src_scale_mod', 0)
        if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
            self.options.pop('src_scale_mod')

        # Subclass hook: ask model-specific options.
        self.onInitializeOptions(self.iter == 0, ask_override)

        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        # Subclass hook: build the network(s).
        self.onInitialize()

        self.options['batch_size'] = self.batch_size

        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            # Per-GPU model instances get their own history/backup dirs.
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / (
                    '%s_history' % (self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%s_autobackups' % (self.get_model_name()))
            else:
                self.preview_history_path = self.model_path / (
                    '%d_%s_history' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%d_%s_autobackups' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))

            if self.autobackup:
                self.autobackup_current_hour = time.localtime().tm_hour

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    # Fresh model: clear any stale preview history.
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for i, generator in enumerate(self.generator_list):
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            if self.sample_for_preview is None or choose_preview_history:
                if choose_preview_history and io.is_support_windows():
                    # Let the user page through samples until confirmed.
                    wnd_name = "[p] - next. [enter] - confirm."
                    io.named_window(wnd_name)
                    io.capture_keys(wnd_name)
                    choosed = False
                    while not choosed:
                        self.sample_for_preview = self.generate_next_sample()
                        preview = self.get_static_preview()
                        io.show_image(wnd_name,
                                      (preview * 255).astype(np.uint8))

                        while True:
                            key_events = io.get_key_events(wnd_name)
                            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                                -1] if len(key_events) > 0 else (0, 0, False,
                                                                 False, False)
                            if key == ord('\n') or key == ord('\r'):
                                choosed = True
                                break
                            elif key == ord('p'):
                                break

                            try:
                                io.process_messages(0.1)
                            except KeyboardInterrupt:
                                choosed = True

                    io.destroy_window(wnd_name)
                else:
                    self.sample_for_preview = self.generate_next_sample()
                self.last_sample = self.sample_for_preview

        # Build and log a human-readable summary of the configuration.
        model_summary_text = []

        model_summary_text += ["===== Model summary ====="]
        model_summary_text += ["== Model name: " + self.get_model_name()]
        model_summary_text += ["=="]
        model_summary_text += ["== Current iteration: " + str(self.iter)]
        model_summary_text += ["=="]
        model_summary_text += ["== Model options:"]
        for key in self.options.keys():
            model_summary_text += ["== |== %s : %s" % (key, self.options[key])]

        if self.device_config.multi_gpu:
            model_summary_text += ["== |== multi_gpu : True "]

        model_summary_text += ["== Running on:"]
        if self.device_config.cpu_only:
            model_summary_text += ["== |== [CPU]"]
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += [
                    "== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))
                ]

        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[
                0] == 2:
            model_summary_text += ["=="]
            model_summary_text += [
                "== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."
            ]
            model_summary_text += [
                "== If training does not start, close all programs and try again."
            ]
            model_summary_text += [
                "== Also you can disable Windows Aero Desktop to get extra free VRAM."
            ]
            model_summary_text += ["=="]

        model_summary_text += ["========================="]
        model_summary_text = "\r\n".join(model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)
# Beispiel #6
# 0
    def get_data(self, host_dict):
        """Produce the next work item for a subprocess client.

        In non-manual mode simply pops the next queued item.  In manual
        mode runs an interactive loop: the user positions/sizes a face
        rectangle with the mouse and confirms or skips frames with the
        keyboard; a Data object describing the chosen rect is returned,
        or None when no work remains.
        """
        if not self.manual:
            if len(self.input_data) > 0:
                return self.input_data.pop(0)
        else:
            need_remark_face = False
            redraw_needed = False
            while len(self.input_data) > 0:
                data = self.input_data[0]
                filename, data_rects, data_landmarks = data.filename, data.rects, data.landmarks
                is_frame_done = False

                if need_remark_face:  # need remark image from input data that already has a marked face?
                    need_remark_face = False
                    if len(
                            data_rects
                    ) != 0:  # If there was already a face then lock the rectangle to it until the mouse is clicked
                        self.rect = data_rects.pop()
                        self.landmarks = data_landmarks.pop()
                        data_rects.clear()
                        data_landmarks.clear()
                        redraw_needed = True
                        self.rect_locked = True
                        self.rect_size = (self.rect[2] - self.rect[0]) / 2
                        self.x = (self.rect[0] + self.rect[2]) / 2
                        self.y = (self.rect[1] + self.rect[3]) / 2

                if len(data_rects) == 0:
                    # Load (or reuse the cached) original frame image.
                    if self.cache_original_image[0] == filename:
                        self.original_image = self.cache_original_image[1]
                    else:
                        self.original_image = imagelib.normalize_channels(
                            cv2_imread(filename), 3)

                        self.cache_original_image = (filename,
                                                     self.original_image)

                    (h, w, c) = self.original_image.shape
                    # Scale the view to the requested window size
                    # (assuming a 16:9 screen ratio for the height).
                    self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / (
                        h * (16.0 / 9.0))

                    # Cache the resized display image, keyed on shape,
                    # scale and filename.
                    if self.cache_image[0] == (h, w, c) + (self.view_scale,
                                                           filename):
                        self.image = self.cache_image[1]
                    else:
                        self.image = cv2.resize(self.original_image, (int(
                            w * self.view_scale), int(h * self.view_scale)),
                                                interpolation=cv2.INTER_LINEAR)
                        self.cache_image = ((h, w, c) +
                                            (self.view_scale, filename),
                                            self.image)

                    (h, w, c) = self.image.shape

                    # Render (or reuse) the on-screen help text overlay.
                    sh = (0, 0, w, min(100, h))
                    if self.cache_text_lines_img[0] == sh:
                        self.text_lines_img = self.cache_text_lines_img[1]
                    else:
                        self.text_lines_img = (imagelib.get_draw_text_lines(
                            self.image, sh, [
                                '[Mouse click] - lock/unlock selection',
                                '[Mouse wheel] - change rect',
                                '[Enter] / [Space] - confirm / skip frame',
                                '[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
                                '[a] - accuracy on/off (more fps)',
                                '[h] - hide this help'
                            ], (1, 1, 1)) * 255).astype(np.uint8)

                        self.cache_text_lines_img = (sh, self.text_lines_img)

                    # Interactive event loop: runs until the user confirms,
                    # skips, or navigates, or until the rect/mouse changes
                    # enough to require re-extraction.
                    while True:
                        io.process_messages(0.0001)

                        new_x = self.x
                        new_y = self.y
                        new_rect_size = self.rect_size

                        mouse_events = io.get_mouse_events(self.wnd_name)
                        for ev in mouse_events:
                            (x, y, ev, flags) = ev
                            if ev == io.EVENT_MOUSEWHEEL and not self.rect_locked:
                                # Wheel resizes the rect; larger rects grow
                                # in proportionally larger steps.
                                mod = 1 if flags > 0 else -1
                                diff = 1 if new_rect_size <= 40 else np.clip(
                                    new_rect_size / 10, 1, 10)
                                new_rect_size = max(5,
                                                    new_rect_size + diff * mod)
                            elif ev == io.EVENT_LBUTTONDOWN:
                                self.rect_locked = not self.rect_locked
                                self.extract_needed = True
                            elif not self.rect_locked:
                                new_x = np.clip(x, 0, w - 1) / self.view_scale
                                new_y = np.clip(y, 0, h - 1) / self.view_scale

                        # Only the most recent key event is considered.
                        key_events = io.get_key_events(self.wnd_name)
                        key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                            -1] if len(key_events) > 0 else (0, 0, False,
                                                             False, False)

                        if key == ord('\r') or key == ord('\n'):
                            #confirm frame
                            is_frame_done = True
                            data_rects.append(self.rect)
                            data_landmarks.append(self.landmarks)
                            break
                        elif key == ord(' '):
                            #confirm skip frame
                            is_frame_done = True
                            break
                        elif key == ord(',') and len(self.result) > 0:
                            #go prev frame

                            if self.rect_locked:
                                self.rect_locked = False
                                # Only save the face if the rect is still locked
                                data_rects.append(self.rect)
                                data_landmarks.append(self.landmarks)

                            self.input_data.insert(0, self.result.pop())
                            io.progress_bar_inc(-1)
                            need_remark_face = True

                            break
                        elif key == ord('.'):
                            #go next frame

                            if self.rect_locked:
                                self.rect_locked = False
                                # Only save the face if the rect is still locked
                                data_rects.append(self.rect)
                                data_landmarks.append(self.landmarks)

                            need_remark_face = True
                            is_frame_done = True
                            break
                        elif key == ord('q'):
                            #skip remaining

                            if self.rect_locked:
                                self.rect_locked = False
                                data_rects.append(self.rect)
                                data_landmarks.append(self.landmarks)

                            # Flush all remaining frames to the result queue.
                            while len(self.input_data) > 0:
                                self.result.append(self.input_data.pop(0))
                                io.progress_bar_inc(1)

                            break

                        elif key == ord('h'):
                            self.hide_help = not self.hide_help
                            break
                        elif key == ord('a'):
                            self.landmarks_accurate = not self.landmarks_accurate
                            break

                        # Rect moved/resized or a redraw was requested:
                        # hand a new extraction job to the subprocess.
                        if self.x != new_x or \
                           self.y != new_y or \
                           self.rect_size != new_rect_size or \
                           self.extract_needed or \
                           redraw_needed:
                            self.x = new_x
                            self.y = new_y
                            self.rect_size = new_rect_size
                            self.rect = (int(self.x - self.rect_size),
                                         int(self.y - self.rect_size),
                                         int(self.x + self.rect_size),
                                         int(self.y + self.rect_size))

                            if redraw_needed:
                                redraw_needed = False
                                return ExtractSubprocessor.Data(
                                    filename,
                                    landmarks_accurate=self.landmarks_accurate)
                            else:
                                return ExtractSubprocessor.Data(
                                    filename,
                                    rects=[self.rect],
                                    landmarks_accurate=self.landmarks_accurate)

                else:
                    is_frame_done = True

                if is_frame_done:
                    self.result.append(data)
                    self.input_data.pop(0)
                    io.progress_bar_inc(1)
                    self.extract_needed = True
                    self.rect_locked = False

        return None
# Beispiel #7
# 0
def main(args, device_args):
    """Entry point for the trainer UI.

    Spawns the training loop in a background thread (``trainer_thread``) and
    then drives one of three front-ends, selected from *args*:

    * ``flask_preview`` — serve previews through a Flask/SocketIO app; the
      preview pane is written to ``<model_path>/preview.jpg`` and pushed to
      clients via a ``'preview'`` socket event.
    * ``no_preview`` — headless; just pump messages until a ``'close'`` op.
    * otherwise — a native preview window with keyboard control
      (Enter sends 'close', 's' save, 'p' request preview, 'l' cycle the
      loss-history range, space next preview, '-'/'='/'+' zoom).

    Communication with the trainer thread uses two queues: ``s2c`` carries
    commands to the trainer, ``c2s`` carries results back; ``e`` is set by
    the trainer once the initial model load completes.
    """
    io.log_info("Running trainer.\r\n")

    no_preview = args.get('no_preview', False)
    flask_preview = args.get('flask_preview', False)

    # UI <-> trainer-thread channels.
    s2c = queue.Queue()
    c2s = queue.Queue()
    e = threading.Event()

    # Shared preview/UI state, reset per front-end below.
    previews = None
    loss_history = None
    selected_preview = 0
    update_preview = False
    is_waiting_preview = False
    show_last_history_iters_count = 0
    iteration = 0
    batch_size = 1
    zoom = Zoom.ZOOM_100

    if flask_preview:
        from flaskr.app import create_flask_app
        s2flask = queue.Queue()
        socketio, flask_app = create_flask_app(s2c, c2s, s2flask, args)

        # NOTE(review): re-created here, shadowing the Event made above.
        e = threading.Event()
        thread = threading.Thread(target=trainer_thread,
                                  args=(s2c, c2s, e, args, device_args,
                                        socketio))
        thread.start()

        e.wait()  # Wait for initial load to occur.

        # Run the web server on its own thread so this loop can keep
        # pumping queue messages.
        flask_t = threading.Thread(target=socketio.run,
                                   args=(flask_app, ),
                                   kwargs={
                                       'debug': True,
                                       'use_reloader': False
                                   })
        flask_t.start()

        while True:
            if not c2s.empty():
                item = c2s.get()
                op = item['op']
                if op == 'show':
                    is_waiting_preview = False
                    loss_history = item[
                        'loss_history'] if 'loss_history' in item.keys(
                        ) else None
                    previews = item['previews'] if 'previews' in item.keys(
                    ) else None
                    iteration = item['iter'] if 'iter' in item.keys() else 0
                    # batch_size = input['batch_size'] if 'iter' in input.keys() else 1
                    if previews is not None:
                        update_preview = True
                elif op == 'update':
                    if not is_waiting_preview:
                        is_waiting_preview = True
                    s2c.put({'op': 'preview'})
                elif op == 'next_preview':
                    selected_preview = (selected_preview + 1) % len(previews)
                    update_preview = True
                elif op == 'change_history_range':
                    # Cycle 0 -> 5k -> 10k -> 50k -> 100k -> 0 shown iters.
                    if show_last_history_iters_count == 0:
                        show_last_history_iters_count = 5000
                    elif show_last_history_iters_count == 5000:
                        show_last_history_iters_count = 10000
                    elif show_last_history_iters_count == 10000:
                        show_last_history_iters_count = 50000
                    elif show_last_history_iters_count == 50000:
                        show_last_history_iters_count = 100000
                    elif show_last_history_iters_count == 100000:
                        show_last_history_iters_count = 0
                    update_preview = True
                elif op == 'close':
                    s2c.put({'op': 'close'})
                    break
                elif op == 'zoom_prev':
                    zoom = zoom.prev()
                    update_preview = True
                elif op == 'zoom_next':
                    zoom = zoom.next()
                    update_preview = True

            if update_preview:
                update_preview = False
                selected_preview = selected_preview % len(previews)
                preview_pane_image = create_preview_pane_image(
                    previews, selected_preview, loss_history,
                    show_last_history_iters_count, iteration, batch_size, zoom)
                # io.show_image(wnd_name, preview_pane_image)
                # Persist the pane so the web client can fetch it as a file.
                model_path = Path(args.get('model_path', ''))
                filename = 'preview.jpg'
                preview_file = str(model_path / filename)
                cv2.imwrite(preview_file, preview_pane_image)
                s2flask.put({'op': 'show'})
                # NOTE(review): loss_history may still be None here if the
                # trainer never sent one — verify before indexing.
                socketio.emit('preview', {
                    'iter': iteration,
                    'loss': loss_history[-1]
                })
            try:
                io.process_messages(0.01)
            except KeyboardInterrupt:
                s2c.put({'op': 'close'})
    else:
        thread = threading.Thread(target=trainer_thread,
                                  args=(s2c, c2s, e, args, device_args))
        thread.start()

        e.wait()  # Wait for initial load to occur.

        if no_preview:
            # Headless mode: only watch for the trainer's 'close'.
            while True:
                if not c2s.empty():
                    item = c2s.get()
                    op = item.get('op', '')
                    if op == 'close':
                        break
                try:
                    io.process_messages(0.1)
                except KeyboardInterrupt:
                    s2c.put({'op': 'close'})
        else:
            wnd_name = "Training preview"
            io.named_window(wnd_name)
            io.capture_keys(wnd_name)

            previews = None
            loss_history = None
            selected_preview = 0
            update_preview = False
            is_showing = False
            is_waiting_preview = False
            show_last_history_iters_count = 0
            iteration = 0
            batch_size = 1
            zoom = Zoom.ZOOM_100

            while True:
                if not c2s.empty():
                    item = c2s.get()
                    op = item['op']
                    if op == 'show':
                        is_waiting_preview = False
                        loss_history = item[
                            'loss_history'] if 'loss_history' in item.keys(
                            ) else None
                        previews = item['previews'] if 'previews' in item.keys(
                        ) else None
                        iteration = item['iter'] if 'iter' in item.keys(
                        ) else 0
                        # batch_size = input['batch_size'] if 'iter' in input.keys() else 1
                        if previews is not None:
                            update_preview = True
                    elif op == 'close':
                        break

                if update_preview:
                    update_preview = False
                    selected_preview = selected_preview % len(previews)
                    preview_pane_image = create_preview_pane_image(
                        previews, selected_preview, loss_history,
                        show_last_history_iters_count, iteration, batch_size,
                        zoom)
                    io.show_image(wnd_name, preview_pane_image)

                # Only the most recent key event this tick is considered.
                key_events = io.get_key_events(wnd_name)
                key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                    -1] if len(key_events) > 0 else (0, 0, False, False, False)

                if key == ord('\n') or key == ord('\r'):
                    s2c.put({'op': 'close'})
                elif key == ord('s'):
                    s2c.put({'op': 'save'})
                elif key == ord('p'):
                    if not is_waiting_preview:
                        is_waiting_preview = True
                        s2c.put({'op': 'preview'})
                elif key == ord('l'):
                    # Cycle 0 -> 5k -> 10k -> 50k -> 100k -> 0 shown iters.
                    if show_last_history_iters_count == 0:
                        show_last_history_iters_count = 5000
                    elif show_last_history_iters_count == 5000:
                        show_last_history_iters_count = 10000
                    elif show_last_history_iters_count == 10000:
                        show_last_history_iters_count = 50000
                    elif show_last_history_iters_count == 50000:
                        show_last_history_iters_count = 100000
                    elif show_last_history_iters_count == 100000:
                        show_last_history_iters_count = 0
                    update_preview = True
                elif key == ord(' '):
                    selected_preview = (selected_preview + 1) % len(previews)
                    update_preview = True
                elif key == ord('-'):
                    zoom = zoom.prev()
                    update_preview = True
                elif key == ord('=') or key == ord('+'):
                    zoom = zoom.next()
                    update_preview = True
                try:
                    io.process_messages(0.1)
                except KeyboardInterrupt:
                    s2c.put({'op': 'close'})

            io.destroy_all_windows()
    def run(self):
        """Run the interactive relighting editor loop.

        Opens a window, lets the user drag with the left mouse button to set
        the altitude/azimuth of the currently selected light, and edit the
        light list with the keyboard:

        * q — pick a new face
        * w / e — select previous / next light
        * r / t — add / remove a light
        * a / s — decrease / increase the selected light's intensity
        * Esc or Enter — exit

        Returns:
            The edited list of (altitude, azimuth, intensity) triples
            (``self.alt_azi_ar``).
        """
        wnd_name = "Relighter"
        io.named_window(wnd_name)
        io.capture_keys(wnd_name)
        io.capture_mouse(wnd_name)

        zoom_factor = 1.0

        is_angle_editing = False

        is_exit = False
        while not is_exit:
            io.process_messages(0.0001)

            # Dragging with LMB held edits the selected light's direction.
            mouse_events = io.get_mouse_events(wnd_name)
            for ev in mouse_events:
                (x, y, ev, flags) = ev
                if ev == io.EVENT_LBUTTONDOWN:
                    is_angle_editing = True

                if ev == io.EVENT_LBUTTONUP:
                    is_angle_editing = False

                if is_angle_editing:
                    h, w, c = self.current_img_shape

                    alt, azi, inten = self.alt_azi_ar[self.alt_azi_cur]
                    # Map vertical position to altitude and horizontal
                    # position to azimuth, each clipped to [-90, 90] degrees.
                    # BUGFIX: previously y was divided by w and x by h, which
                    # is only correct for square images.
                    alt = np.clip((0.5 - y / h) * 2.0, -1, 1) * 90
                    azi = np.clip((x / w - 0.5) * 2.0, -1, 1) * 90
                    self.alt_azi_ar[self.alt_azi_cur] = (alt, azi, inten)

                    self.set_screen_changed()

            # Only the most recent key event this tick is considered.
            key_events = io.get_key_events(wnd_name)
            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                -1] if len(key_events) > 0 else (0, 0, False, False, False)

            if key != 0:
                if chr_key == 'q':
                    self.pick_new_face()
                elif chr_key == 'w':
                    self.alt_azi_cur = np.clip(self.alt_azi_cur - 1, 0,
                                               len(self.alt_azi_ar) - 1)
                    self.set_screen_changed()
                elif chr_key == 'e':
                    self.alt_azi_cur = np.clip(self.alt_azi_cur + 1, 0,
                                               len(self.alt_azi_ar) - 1)
                    self.set_screen_changed()
                elif chr_key == 'r':
                    #add direction
                    self.alt_azi_ar += [[0, 0, 1.0]]
                    self.alt_azi_cur += 1
                    self.set_screen_changed()
                elif chr_key == 't':
                    # Remove the selected light, but always keep at least one.
                    if len(self.alt_azi_ar) > 1:
                        self.alt_azi_ar.pop(self.alt_azi_cur)
                        self.alt_azi_cur = np.clip(self.alt_azi_cur, 0,
                                                   len(self.alt_azi_ar) - 1)
                        self.set_screen_changed()
                elif chr_key == 'a':
                    alt, azi, inten = self.alt_azi_ar[self.alt_azi_cur]
                    inten = np.clip(inten - 0.1, 0.0, 1.0)
                    self.alt_azi_ar[self.alt_azi_cur] = (alt, azi, inten)
                    self.set_screen_changed()
                elif chr_key == 's':
                    alt, azi, inten = self.alt_azi_ar[self.alt_azi_cur]
                    inten = np.clip(inten + 0.1, 0.0, 1.0)
                    self.alt_azi_ar[self.alt_azi_cur] = (alt, azi, inten)
                    self.set_screen_changed()
                elif key == 27 or chr_key == '\r' or chr_key == '\n':  #esc
                    is_exit = True

            # Redraw only when something changed.
            if self.switch_screen_changed():
                screen = self.make_screen()
                if zoom_factor != 1.0:
                    h, w, c = screen.shape
                    screen = cv2.resize(
                        screen, (int(w * zoom_factor), int(h * zoom_factor)))
                io.show_image(wnd_name, screen)

        io.destroy_window(wnd_name)

        return self.alt_azi_ar
Beispiel #9
0
def main(input_dir, output_dir):
    """Interactive GrabCut-based face-mask labeling tool.

    For every DFL image (.png/.jpg with embedded face landmarks) in
    *input_dir*, builds a GrabCut tri-mask from the landmarks, runs
    ``cv2.grabCut`` and shows a horizontal strip of before/after/result
    images in a window.  Press space to advance to the next image.

    Args:
        input_dir: directory of DFL-aligned face images.
        output_dir: created if missing; no files are written to it in this
            function (reserved for saving results).

    Raises:
        ValueError: if *input_dir* does not exist.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')

    if not output_path.exists():
        output_path.mkdir(parents=True)

    wnd_name = "Labeling tool"
    io.named_window(wnd_name)
    io.capture_mouse(wnd_name)
    io.capture_keys(wnd_name)

    #for filename in io.progress_bar_generator (Path_utils.get_image_paths(input_path), desc="Labeling"):
    for filename in Path_utils.get_image_paths(input_path):
        filepath = Path(filename)

        # Load landmark metadata embedded in the DFL image.
        if filepath.suffix == '.png':
            dflimg = DFLPNG.load(str(filepath))
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load(str(filepath))
        else:
            dflimg = None

        if dflimg is None:
            io.log_err("%s is not a dfl image file" % (filepath.name))
            continue

        lmrks = dflimg.get_landmarks()
        lmrks_list = lmrks.tolist()
        orig_img = cv2_imread(str(filepath))
        h, w, c = orig_img.shape

        # Hull mask from landmarks, then eroded/dilated variants to seed
        # the GrabCut definite-foreground / definite-background regions.
        mask_orig = LandmarksProcessor.get_image_hull_mask(
            orig_img.shape, lmrks).astype(np.uint8)[:, :, 0]
        ero_dil_rate = w // 8
        mask_ero = cv2.erode(
            mask_orig,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                      (ero_dil_rate, ero_dil_rate)),
            iterations=1)
        mask_dil = cv2.dilate(mask_orig,
                              cv2.getStructuringElement(
                                  cv2.MORPH_ELLIPSE,
                                  (ero_dil_rate, ero_dil_rate)),
                              iterations=1)

        #mask_bg = np.zeros(orig_img.shape[:2],np.uint8)
        mask_bg = 1 - mask_dil
        mask_bgp = np.ones(orig_img.shape[:2],
                           np.uint8)  #default - all background possible
        mask_fg = np.zeros(orig_img.shape[:2], np.uint8)
        mask_fgp = np.zeros(orig_img.shape[:2], np.uint8)

        img = orig_img.copy()

        l_thick = 2

        # NOTE(review): defined but never called below; kept for
        # experimentation.  Draws a polyline plus three parallel offset
        # polylines (probable-fg / bg / probable-bg) into the four masks.
        def draw_4_lines(masks_out, pts, thickness=1):
            fgp, fg, bg, bgp = masks_out
            h, w = fg.shape

            fgp_pts = []
            fg_pts = np.array([pts[i:i + 2] for i in range(len(pts) - 1)])
            bg_pts = []
            bgp_pts = []

            for i in range(len(fg_pts)):
                a, b = line = fg_pts[i]

                ba = b - a
                v = ba / npl.norm(ba)

                # Perpendiculars to the segment, counter-clockwise/clockwise.
                ccpv = np.array([v[1], -v[0]])
                cpv = np.array([-v[1], v[0]])
                step = 1 / max(np.abs(cpv))

                # np.int was removed in NumPy 1.24; use builtin int.
                fgp_pts.append(
                    np.clip(line + ccpv * step * thickness, 0,
                            w - 1).astype(int))
                bg_pts.append(
                    np.clip(line + cpv * step * thickness, 0,
                            w - 1).astype(int))
                bgp_pts.append(
                    np.clip(line + cpv * step * thickness * 2, 0,
                            w - 1).astype(int))

            fgp_pts = np.array(fgp_pts)
            bg_pts = np.array(bg_pts)
            bgp_pts = np.array(bgp_pts)

            cv2.polylines(fgp, fgp_pts, False, (1, ), thickness=thickness)
            cv2.polylines(fg, fg_pts, False, (1, ), thickness=thickness)
            cv2.polylines(bg, bg_pts, False, (1, ), thickness=thickness)
            cv2.polylines(bgp, bgp_pts, False, (1, ), thickness=thickness)

        # Draw the polyline through pts into each mask, offset perpendicular
        # to each segment by `step` times the line thickness.
        def draw_lines(masks_steps, pts, thickness=1):
            lines = np.array([pts[i:i + 2] for i in range(len(pts) - 1)])

            for mask, step in masks_steps:
                h, w = mask.shape

                mask_lines = []
                for i in range(len(lines)):
                    a, b = line = lines[i]
                    ba = b - a
                    ba_len = npl.norm(ba)
                    if ba_len != 0:
                        v = ba / ba_len
                        pv = np.array([-v[1], v[0]])
                        pv_inv_max = 1 / max(np.abs(pv))
                        # np.int was removed in NumPy 1.24; use builtin int.
                        mask_lines.append(
                            np.clip(line + pv * pv_inv_max * thickness * step,
                                    0, w - 1).astype(int))
                    else:
                        # Degenerate (zero-length) segment: draw as-is.
                        mask_lines.append(np.array(line, dtype=int))
                cv2.polylines(mask,
                              mask_lines,
                              False, (1, ),
                              thickness=thickness)

        # Fill the convex hull of pts into mask_out, optionally scaled
        # around the hull's centroid.
        def draw_fill_convex(mask_out, pts, scale=1.0):
            hull = cv2.convexHull(np.array(pts))

            if scale != 1.0:
                pts_count = hull.shape[0]

                sum_x = np.sum(hull[:, 0, 0])
                sum_y = np.sum(hull[:, 0, 1])

                hull_center = np.array([sum_x / pts_count, sum_y / pts_count])
                hull = hull_center + (hull - hull_center) * scale
                hull = hull.astype(pts.dtype)
            cv2.fillConvexPoly(mask_out, hull, (1, ))

        # Color-code a GrabCut mask for visualization.
        def get_gc_mask_bgr(gc_mask):
            h, w = gc_mask.shape
            bgr = np.zeros((h, w, 3), dtype=np.uint8)

            bgr[gc_mask == 0] = (0, 0, 0)
            bgr[gc_mask == 1] = (255, 255, 255)
            bgr[gc_mask == 2] = (0, 0, 255)  #RED
            bgr[gc_mask == 3] = (0, 255, 0)  #GREEN
            return bgr

        # Collapse GrabCut labels to binary: fg/probable-fg -> 1, else 0.
        def get_gc_mask_result(gc_mask):
            # np.int was removed in NumPy 1.24; use builtin int.
            return np.where((gc_mask == 1) + (gc_mask == 3), 1,
                            0).astype(int)

        #convex inner of right chin to end of right eyebrow
        #draw_fill_convex ( mask_fgp, lmrks_list[8:17]+lmrks_list[26:27] )

        #convex inner of start right chin to right eyebrow
        #draw_fill_convex ( mask_fgp, lmrks_list[8:9]+lmrks_list[22:27] )

        #convex inner of nose
        draw_fill_convex(mask_fgp, lmrks[27:36])

        #convex inner of nose half
        draw_fill_convex(mask_fg, lmrks[27:36], scale=0.5)

        #left corner of mouth to left corner of nose
        #draw_lines ( [ (mask_fg,0),   ], lmrks_list[49:50]+lmrks_list[32:33], l_thick)

        #convex inner: right corner of nose to centers of eyebrows
        #draw_fill_convex ( mask_fgp, lmrks_list[35:36]+lmrks_list[19:20]+lmrks_list[24:25])

        #right corner of mouth to right corner of nose
        #draw_lines ( [ (mask_fg,0),   ], lmrks_list[54:55]+lmrks_list[35:36], l_thick)

        #left eye
        #draw_fill_convex ( mask_fg, lmrks_list[36:40] )
        #right eye
        #draw_fill_convex ( mask_fg, lmrks_list[42:48] )

        #right chin
        draw_lines([
            (mask_bg, 0),
            (mask_fg, -1),
        ], lmrks[8:17], l_thick)

        #left eyebrow center to right eyeprow center
        draw_lines([
            (mask_bg, -1),
            (mask_fg, 0),
        ], lmrks_list[19:20] + lmrks_list[24:25], l_thick)
        #        #draw_lines ( [ (mask_bg,-1), (mask_fg,0),   ], lmrks_list[24:25] + lmrks_list[19:17:-1], l_thick)

        #half right eyebrow to end of right chin
        draw_lines([
            (mask_bg, -1),
            (mask_fg, 0),
        ], lmrks_list[24:27] + lmrks_list[16:17], l_thick)

        #import code
        #code.interact(local=dict(globals(), **locals()))

        #compose mask layers
        # Later assignments win: definite bg/fg override the probable layers.
        gc_mask = np.zeros(orig_img.shape[:2], np.uint8)
        gc_mask[mask_bgp == 1] = 2
        gc_mask[mask_fgp == 1] = 3
        gc_mask[mask_bg == 1] = 0
        gc_mask[mask_fg == 1] = 1

        gc_bgr_before = get_gc_mask_bgr(gc_mask)

        #io.show_image (wnd_name, gc_mask )

        ##points, hierarcy = cv2.findContours(original_mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        ##gc_mask = ( (1-erode_mask)*2 + erode_mask )# * dilate_mask
        #gc_mask = (1-erode_mask)*2 + erode_mask
        #cv2.addWeighted(
        #gc_mask = mask_0_27 + (1-mask_0_27)*2
        #
        ##import code
        ##code.interact(local=dict(globals(), **locals()))
        #
        #rect = (1,1,img.shape[1]-2,img.shape[0]-2)
        #
        #
        # Refine the tri-mask in place; 5 iterations, mask-initialized.
        cv2.grabCut(img, gc_mask, None, np.zeros((1, 65), np.float64),
                    np.zeros((1, 65), np.float64), 5, cv2.GC_INIT_WITH_MASK)

        gc_bgr = get_gc_mask_bgr(gc_mask)
        gc_mask_result = get_gc_mask_result(gc_mask)
        gc_mask_result_1 = gc_mask_result[:, :, np.newaxis]

        #import code
        #code.interact(local=dict(globals(), **locals()))
        orig_img_gc_layers_masked = (0.5 * orig_img + 0.5 * gc_bgr).astype(
            np.uint8)
        orig_img_gc_before_layers_masked = (0.5 * orig_img +
                                            0.5 * gc_bgr_before).astype(
                                                np.uint8)

        pink_bg = np.full(orig_img.shape, (255, 0, 255), dtype=np.uint8)

        orig_img_result = orig_img * gc_mask_result_1
        orig_img_result_pinked = orig_img_result + pink_bg * (1 -
                                                              gc_mask_result_1)

        #io.show_image (wnd_name, blended_img)

        ##gc_mask, bgdModel, fgdModel =
        #
        #mask2 = np.where((gc_mask==1) + (gc_mask==3),255,0).astype('uint8')[:,:,np.newaxis]
        #mask2 = np.repeat(mask2, (3,), -1)
        #
        ##mask2 = np.where(gc_mask!=0,255,0).astype('uint8')
        #blended_img = orig_img #-\
        #              #0.3 * np.full(original_img.shape, (50,50,50)) * (1-mask_0_27)[:,:,np.newaxis]
        #              #0.3 * np.full(original_img.shape, (50,50,50)) * (1-dilate_mask)[:,:,np.newaxis] +\
        #              #0.3 * np.full(original_img.shape, (50,50,50)) * (erode_mask)[:,:,np.newaxis]
        #blended_img = np.clip(blended_img, 0, 255).astype(np.uint8)
        ##import code
        ##code.interact(local=dict(globals(), **locals()))
        orig_img_lmrked = orig_img.copy()
        LandmarksProcessor.draw_landmarks(orig_img_lmrked,
                                          lmrks,
                                          transparent_mask=True)

        # Side-by-side strip: seed mask, refined mask, original, landmarks,
        # pink-keyed result, raw result.
        screen = np.concatenate([
            orig_img_gc_before_layers_masked,
            orig_img_gc_layers_masked,
            orig_img,
            orig_img_lmrked,
            orig_img_result_pinked,
            orig_img_result,
        ],
                                axis=1)

        io.show_image(wnd_name, screen.astype(np.uint8))

        # Wait for the user; space advances to the next image.
        while True:
            io.process_messages()

            for (x, y, ev, flags) in io.get_mouse_events(wnd_name):
                pass
                #print (x,y,ev,flags)

            key_events = [ev for ev, in io.get_key_events(wnd_name)]
            for key in key_events:
                if key == ord('1'):
                    pass
                if key == ord('2'):
                    pass
                if key == ord('3'):
                    pass

            if ord(' ') in key_events:
                break

    # Debug drop-in: opens an interactive console when all images are done.
    import code
    code.interact(local=dict(globals(), **locals()))


#original_mask = np.ones(original_img.shape[:2],np.uint8)*2
#cv2.drawContours(original_mask, points, -1, (1,), 1)
Beispiel #10
0
    def __init__(self,
                 model_path,
                 training_data_src_path=None,
                 training_data_dst_path=None,
                 pretraining_data_path=None,
                 debug=False,
                 device_args=None,
                 ask_enable_autobackup=True,
                 ask_write_preview_history=True,
                 ask_target_iter=True,
                 ask_batch_size=True,
                 ask_sort_by_yaw=True,
                 ask_random_flip=True,
                 ask_src_scale_mod=True):

        device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
        device_args['cpu_only'] = device_args.get('cpu_only', False)

        if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
            idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
            if len(idxs_names_list) > 1:
                io.log_info("You have multi GPUs in a system: ")
                for idx, name in idxs_names_list:
                    io.log_info("[%d] : %s" % (idx, name))

                device_args['force_gpu_idx'] = io.input_int(
                    "Which GPU idx to choose? ( skip: best GPU ) : ", -1,
                    [x[0] for x in idxs_names_list])
        self.device_args = device_args

        self.device_config = nnlib.DeviceConfig(allow_growth=False,
                                                **self.device_args)

        io.log_info("加载模型...")

        self.model_path = model_path
        self.model_data_path = Path(
            self.get_strpath_storage_for_file('data.dat'))

        self.training_data_src_path = training_data_src_path
        self.training_data_dst_path = training_data_dst_path
        self.pretraining_data_path = pretraining_data_path

        self.src_images_paths = None
        self.dst_images_paths = None
        self.src_yaw_images_paths = None
        self.dst_yaw_images_paths = None
        self.src_data_generator = None
        self.dst_data_generator = None
        self.debug = debug
        self.is_training_mode = (training_data_src_path is not None
                                 and training_data_dst_path is not None)

        self.iter = 0
        self.options = {}
        self.loss_history = []
        self.sample_for_preview = None

        model_data = {}
        if self.model_data_path.exists():
            model_data = pickle.loads(self.model_data_path.read_bytes())
            self.iter = max(model_data.get('iter', 0),
                            model_data.get('epoch', 0))
            if 'epoch' in self.options:
                self.options.pop('epoch')
            if self.iter != 0:
                self.options = model_data['options']
                self.loss_history = model_data.get('loss_history', [])
                self.sample_for_preview = model_data.get(
                    'sample_for_preview', None)

        ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time(
            "\n2秒内按回车键[Enter]可以重新配置部分参数。\n\n", 5 if io.is_colab() else 2)

        yn_str = {True: 'y', False: 'n'}

        if self.iter == 0:
            io.log_info("\n第一次启动模型. 请输入模型选项,当再次启动时会加载当前配置.\n")

        if ask_enable_autobackup and (self.iter == 0 or ask_override):
            default_autobackup = False if self.iter == 0 else self.options.get(
                'autobackup', False)
            self.options['autobackup'] = io.input_bool(
                "启动备份? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]),
                default_autobackup,
                help_message=
                "自动备份模型文件,过去15小时每小时备份一次。 位于model / <> _ autobackups /")
        else:
            self.options['autobackup'] = self.options.get('autobackup', False)

        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get(
                'write_preview_history', False)
            self.options['write_preview_history'] = io.input_bool(
                "保存历史预览图[write_preview_history]? (y/n ?:help skip:%s) : " %
                (yn_str[default_write_preview_history]),
                default_write_preview_history,
                help_message="预览图保存在<模型名称>_history文件夹。")
        else:
            self.options['write_preview_history'] = self.options.get(
                'write_preview_history', False)

        if (self.iter == 0 or ask_override) and self.options[
                'write_preview_history'] and io.is_support_windows():
            choose_preview_history = io.input_bool(
                "选择预览图图片[write_preview_history]? (y/n skip:%s) : " %
                (yn_str[False]), False)
        else:
            choose_preview_history = False

        if ask_target_iter:
            if (self.iter == 0 or ask_override):
                self.options['target_iter'] = max(
                    0,
                    io.input_int(
                        "目标迭代次数[Target iteration] (skip:unlimited/default) : ",
                        0))
            else:
                self.options['target_iter'] = max(
                    model_data.get('target_iter', 0),
                    self.options.get('target_epoch', 0))
                if 'target_epoch' in self.options:
                    self.options.pop('target_epoch')

        if ask_batch_size and (self.iter == 0 or ask_override):
            default_batch_size = 0 if self.iter == 0 else self.options.get(
                'batch_size', 0)
            self.options['batch_size'] = max(
                0,
                io.input_int(
                    "批处理大小[Batch_size] (?:help skip:%d) : " %
                    (default_batch_size),
                    default_batch_size,
                    help_message=
                    "较大的批量大小更适合神经网络[NN]的泛化,但它可能导致内存不足[OOM]的错误。根据你显卡配置合理设置改选项,默认为4,推荐16."
                ))
        else:
            self.options['batch_size'] = self.options.get('batch_size', 0)

        if ask_sort_by_yaw:
            if (self.iter == 0 or ask_override):
                default_sort_by_yaw = self.options.get('sort_by_yaw', False)
                self.options['sort_by_yaw'] = io.input_bool(
                    "根据侧脸排序[Feed faces to network sorted by yaw]? (y/n ?:help skip:%s) : "
                    % (yn_str[default_sort_by_yaw]),
                    default_sort_by_yaw,
                    help_message=
                    "神经网络[NN]不会学习与dst面部方向不匹配的src面部方向。 如果dst脸部有覆盖下颚的头发,请不要启用.")
            else:
                self.options['sort_by_yaw'] = self.options.get(
                    'sort_by_yaw', False)

        if ask_random_flip:
            if (self.iter == 0):
                self.options['random_flip'] = io.input_bool(
                    "随机反转[Flip faces randomly]? (y/n ?:help skip:y) : ",
                    True,
                    help_message=
                    "如果没有此选项,预测的脸部看起来会更自然,但源[src]的脸部集合[faceset]应覆盖所有面部方向,去陪陪目标[dst]的脸部集合[faceset]。"
                )
            else:
                self.options['random_flip'] = self.options.get(
                    'random_flip', True)

        if ask_src_scale_mod:
            if (self.iter == 0):
                self.options['src_scale_mod'] = np.clip(
                    io.input_int(
                        "源脸缩放[Src face scale modifier] % ( -30...30, ?:help skip:0) : ",
                        0,
                        help_message="如果src面部形状比dst宽,请尝试减小此值以获得更好的结果。"), -30,
                    30)
            else:
                self.options['src_scale_mod'] = self.options.get(
                    'src_scale_mod', 0)

        self.autobackup = self.options.get('autobackup', False)
        if not self.autobackup and 'autobackup' in self.options:
            self.options.pop('autobackup')

        self.write_preview_history = self.options.get('write_preview_history',
                                                      False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
            self.options.pop('write_preview_history')

        self.target_iter = self.options.get('target_iter', 0)
        if self.target_iter == 0 and 'target_iter' in self.options:
            self.options.pop('target_iter')

        self.batch_size = self.options.get('batch_size', 0)
        self.sort_by_yaw = self.options.get('sort_by_yaw', False)
        self.random_flip = self.options.get('random_flip', True)

        self.src_scale_mod = self.options.get('src_scale_mod', 0)
        if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
            self.options.pop('src_scale_mod')

        self.onInitializeOptions(self.iter == 0, ask_override)

        nnlib.import_all(self.device_config)
        self.keras = nnlib.keras
        self.K = nnlib.keras.backend

        self.onInitialize()

        self.options['batch_size'] = self.batch_size

        if self.debug or self.batch_size == 0:
            self.batch_size = 1

        if self.is_training_mode:
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / (
                    '%s_history' % (self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%s_autobackups' % (self.get_model_name()))
            else:
                self.preview_history_path = self.model_path / (
                    '%d_%s_history' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))
                self.autobackups_path = self.model_path / (
                    '%d_%s_autobackups' %
                    (self.device_args['force_gpu_idx'], self.get_model_name()))

            if self.autobackup:
                self.autobackup_current_hour = time.localtime().tm_hour

                if not self.autobackups_path.exists():
                    self.autobackups_path.mkdir(exist_ok=True)

            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
                    self.preview_history_path.mkdir(exist_ok=True)
                else:
                    if self.iter == 0:
                        for filename in Path_utils.get_image_paths(
                                self.preview_history_path):
                            Path(filename).unlink()

            if self.generator_list is None:
                raise ValueError('You didnt set_training_data_generators()')
            else:
                for i, generator in enumerate(self.generator_list):
                    if not isinstance(generator, SampleGeneratorBase):
                        raise ValueError(
                            'training data generator is not subclass of SampleGeneratorBase'
                        )

            if self.sample_for_preview is None or choose_preview_history:
                if choose_preview_history and io.is_support_windows():
                    wnd_name = "[p] - next. [enter] - confirm."
                    io.named_window(wnd_name)
                    io.capture_keys(wnd_name)
                    choosed = False
                    while not choosed:
                        self.sample_for_preview = self.generate_next_sample()
                        preview = self.get_static_preview()
                        io.show_image(wnd_name,
                                      (preview * 255).astype(np.uint8))

                        while True:
                            key_events = io.get_key_events(wnd_name)
                            key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                                -1] if len(key_events) > 0 else (0, 0, False,
                                                                 False, False)
                            if key == ord('\n') or key == ord('\r'):
                                choosed = True
                                break
                            elif key == ord('p'):
                                break

                            try:
                                io.process_messages(0.1)
                            except KeyboardInterrupt:
                                choosed = True

                    io.destroy_window(wnd_name)
                else:
                    self.sample_for_preview = self.generate_next_sample()
                self.last_sample = self.sample_for_preview
        model_summary_text = []

        model_summary_text += ["\n===== 模型信息 =====\n"]
        model_summary_text += ["== 模型名称: " + self.get_model_name()]
        model_summary_text += ["=="]
        model_summary_text += ["== 当前迭代: " + str(self.iter)]
        model_summary_text += ["=="]
        model_summary_text += ["== 模型配置信息:"]
        for key in self.options.keys():
            model_summary_text += ["== |== %s : %s" % (key, self.options[key])]

        if self.device_config.multi_gpu:
            model_summary_text += ["== |== multi_gpu : True "]

        model_summary_text += ["== Running on:"]
        if self.device_config.cpu_only:
            model_summary_text += ["== |== [CPU]"]
        else:
            for idx in self.device_config.gpu_idxs:
                model_summary_text += [
                    "== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))
                ]

        if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[
                0] == 2:
            model_summary_text += ["=="]
            model_summary_text += [
                "== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."
            ]
            model_summary_text += [
                "== If training does not start, close all programs and try again."
            ]
            model_summary_text += [
                "== Also you can disable Windows Aero Desktop to get extra free VRAM."
            ]
            model_summary_text += ["=="]

        model_summary_text += ["========================="]
        model_summary_text = "\r\n".join(model_summary_text)
        self.model_summary_text = model_summary_text
        io.log_info(model_summary_text)
Beispiel #11
0
 def get_key_events(self):
     """Return the key events queued for this editor's window."""
     pending = io.get_key_events(self.wnd_name)
     return pending
Beispiel #12
0
def mask_editor_main(input_dir, confirmed_dir=None, skipped_dir=None):
    """Interactively edit include/exclude mask polygons of DFL face images.

    Iterates over all images in ``input_dir``, shows each in a window and
    lets the user draw mask polygons with the mouse. Keyboard actions:
    [q]/[z] previous image, [e] save and move to ``confirmed_dir``,
    [c] save in place, [w] skip and move to ``skipped_dir``, [x] skip in
    place, [esc] quit. Holding [shift] applies the action to 10 frames.

    Args:
        input_dir: directory containing DFL .png / .jpg face images.
        confirmed_dir: destination for images saved with the 'e' key (required).
        skipped_dir: destination for images skipped with the 'w' key (required).

    Raises:
        ValueError: if input_dir does not exist, or confirmed_dir /
            skipped_dir was not supplied.
    """
    input_path = Path(input_dir)

    # BUGFIX: the parameters default to None, but Path(None) raises an
    # opaque TypeError; validate explicitly with a clear message instead.
    if confirmed_dir is None or skipped_dir is None:
        raise ValueError('confirmed_dir and skipped_dir must be specified.')

    confirmed_path = Path(confirmed_dir)
    skipped_path = Path(skipped_dir)

    if not input_path.exists():
        raise ValueError('Input directory not found. Please ensure it exists.')

    if not confirmed_path.exists():
        confirmed_path.mkdir(parents=True)

    if not skipped_path.exists():
        skipped_path.mkdir(parents=True)

    wnd_name = "MaskEditor tool"
    io.named_window(wnd_name)
    io.capture_mouse(wnd_name)
    io.capture_keys(wnd_name)

    image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)]
    done_paths = []

    image_paths_total = len(image_paths)

    # Pending-action counters; values > 1 repeat the action over the
    # following frames (shift-modified key presses set them to 10).
    do_prev_count = 0
    do_save_move_count = 0
    do_save_count = 0
    do_skip_move_count = 0
    do_skip_count = 0

    is_exit = False
    while not is_exit:

        # Take the next image; filepath is None once the queue is exhausted,
        # which shows an empty "end" screen but still allows going back.
        if len(image_paths) > 0:
            filepath = image_paths.pop(0)
        else:
            filepath = None

        if filepath is not None:
            if filepath.suffix == '.png':
                dflimg = DFLPNG.load(str(filepath))
            elif filepath.suffix == '.jpg':
                dflimg = DFLJPG.load(str(filepath))
            else:
                dflimg = None

            if dflimg is None:
                io.log_err("%s is not a dfl image file" % (filepath.name))
                continue

            lmrks = dflimg.get_landmarks()
            ie_polys = dflimg.get_ie_polys()

            img = cv2_imread(str(filepath)) / 255.0
            mask = LandmarksProcessor.get_image_hull_mask(img.shape, lmrks)
        else:
            # End-of-list placeholder: black image, full mask, no polygons.
            img = np.zeros((256, 256, 3))
            mask = np.ones((256, 256, 3))
            ie_polys = None

        def get_status_lines_func():
            # Help/status overlay; closes over the current filepath.
            return [
                'Progress: %d / %d . Current file: %s' %
                (len(done_paths), image_paths_total,
                 str(filepath.name) if filepath is not None else "end"),
                '[Left mouse button] - mark include mask.',
                '[Right mouse button] - mark exclude mask.',
                '[Middle mouse button] - finish current poly.',
                '[Mouse wheel] - undo/redo poly or point. [+ctrl] - undo to begin/redo to end',
                '[q] - prev image. [w] - skip and move to %s. [e] - save and move to %s. '
                % (skipped_path.name, confirmed_path.name),
                '[z] - prev image. [x] - skip. [c] - save. ',
                'hold [shift] - speed up the frame counter by 10.',
                # BUGFIX: a missing comma above previously fused this line
                # with the [esc] hint via implicit string concatenation.
                '[esc] - quit'
            ]

        ed = MaskEditor(img, mask, ie_polys, get_status_lines_func)

        # Renamed from `next` to avoid shadowing the builtin.
        go_next = False
        while not go_next:
            io.process_messages(0.005)

            # Only poll user input when no multi-frame action is pending.
            if do_prev_count + do_save_move_count + do_save_count + do_skip_move_count + do_skip_count == 0:
                for (x, y, ev, flags) in io.get_mouse_events(wnd_name):
                    ed.set_mouse_pos(x, y)
                    if filepath is not None:
                        if ev == io.EVENT_LBUTTONDOWN:
                            ed.mask_point(1)
                        elif ev == io.EVENT_RBUTTONDOWN:
                            ed.mask_point(0)
                        elif ev == io.EVENT_MBUTTONDOWN:
                            ed.mask_finish()
                        elif ev == io.EVENT_MOUSEWHEEL:
                            # Sign bit of flags = wheel direction; 0x8 = ctrl
                            # (undo/redo all the way vs a single step).
                            if flags & 0x80000000 != 0:
                                if flags & 0x8 != 0:
                                    ed.undo_to_begin_point()
                                else:
                                    ed.undo_point()
                            else:
                                if flags & 0x8 != 0:
                                    ed.redo_to_end_point()
                                else:
                                    ed.redo_point()

                for key, chr_key, ctrl_pressed, alt_pressed, shift_pressed in io.get_key_events(
                        wnd_name):
                    if chr_key == 'q' or chr_key == 'z':
                        do_prev_count = 1 if not shift_pressed else 10
                    elif key == 27:  # esc
                        is_exit = True
                        go_next = True
                        break
                    elif filepath is not None:
                        if chr_key == 'e':
                            do_save_move_count = 1 if not shift_pressed else 10
                        elif chr_key == 'c':
                            do_save_count = 1 if not shift_pressed else 10
                        elif chr_key == 'w':
                            do_skip_move_count = 1 if not shift_pressed else 10
                        elif chr_key == 'x':
                            do_skip_count = 1 if not shift_pressed else 10

            if do_prev_count > 0:
                do_prev_count -= 1
                if len(done_paths) > 0:
                    # Push the current file back, then the previous one in
                    # front of it so it is shown next.
                    image_paths.insert(0, filepath)
                    filepath = done_paths.pop(-1)

                    if filepath.parent != input_path:
                        # Previous file had been moved out; move it back first.
                        new_filename_path = input_path / filepath.name
                        filepath.rename(new_filename_path)
                        image_paths.insert(0, new_filename_path)
                    else:
                        image_paths.insert(0, filepath)

                    go_next = True
            elif filepath is not None:
                if do_save_move_count > 0:
                    do_save_move_count -= 1

                    # Embed edited polygons into the image metadata, then
                    # move the file to the confirmed directory.
                    ed.mask_finish()
                    dflimg.embed_and_set(str(filepath),
                                         ie_polys=ed.get_ie_polys())

                    done_paths += [confirmed_path / filepath.name]
                    filepath.rename(done_paths[-1])

                    go_next = True
                elif do_save_count > 0:
                    do_save_count -= 1

                    # Embed edited polygons in place.
                    ed.mask_finish()
                    dflimg.embed_and_set(str(filepath),
                                         ie_polys=ed.get_ie_polys())

                    done_paths += [filepath]

                    go_next = True
                elif do_skip_move_count > 0:
                    do_skip_move_count -= 1

                    done_paths += [skipped_path / filepath.name]
                    filepath.rename(done_paths[-1])

                    go_next = True
                elif do_skip_count > 0:
                    do_skip_count -= 1

                    done_paths += [filepath]

                    go_next = True
            else:
                # On the end screen only "previous" makes sense; drop the rest.
                do_save_move_count = do_save_count = do_skip_move_count = do_skip_count = 0

            if do_prev_count + do_save_move_count + do_save_count + do_skip_move_count + do_skip_count == 0:
                if ed.switch_screen_changed():
                    io.show_image(wnd_name, ed.make_screen())

        io.process_messages(0.005)

    io.destroy_all_windows()
Beispiel #13
0
    def get_data(self, host_dict):
        """Return the next unit of work for the extractor.

        In automatic mode, pops and returns the next queued input item.
        In manual mode, runs the interactive face-marking UI for the first
        queued frame and returns an ExtractSubprocessor.Data describing the
        frame to (re-)extract, or None when the queue is empty.

        NOTE(review): `host_dict` is not used here; presumably kept for the
        subprocessor interface — TODO confirm against the base class.
        """
        if not self.manual:
            # Automatic mode: hand out queued items one at a time.
            if len(self.input_data) > 0:
                return self.input_data.pop(0)
        else:
            need_remark_face = False
            redraw_needed = False
            while len(self.input_data) > 0:
                data = self.input_data[0]
                filename, data_rects, data_landmarks = data.filename, data.rects, data.landmarks
                is_frame_done = False

                if need_remark_face:  # need remark image from input data that already has a marked face?
                    need_remark_face = False
                    if len(
                            data_rects
                    ) != 0:  # If there was already a face then lock the rectangle to it until the mouse is clicked
                        self.rect = data_rects.pop()
                        self.landmarks = data_landmarks.pop()
                        data_rects.clear()
                        data_landmarks.clear()
                        redraw_needed = True
                        self.rect_locked = True
                        # Re-derive rect half-size and center from the stored
                        # rect (left, top, right, bottom).
                        self.rect_size = (self.rect[2] - self.rect[0]) / 2
                        self.x = (self.rect[0] + self.rect[2]) / 2
                        self.y = (self.rect[1] + self.rect[3]) / 2

                if len(data_rects) == 0:
                    # No face marked for this frame yet: run the UI.
                    # Single-entry caches avoid re-reading / re-resizing the
                    # same image on consecutive calls.
                    if self.cache_original_image[0] == filename:
                        self.original_image = self.cache_original_image[1]
                    else:
                        self.original_image = cv2_imread(filename)
                        self.cache_original_image = (filename,
                                                     self.original_image)

                    (h, w, c) = self.original_image.shape
                    # 0 means "native size"; otherwise scale the view to the
                    # requested manual window size (16:9-based height).
                    self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / (
                        h * (16.0 / 9.0))

                    if self.cache_image[0] == (h, w, c) + (self.view_scale,
                                                           filename):
                        self.image = self.cache_image[1]
                    else:
                        self.image = cv2.resize(self.original_image, (int(
                            w * self.view_scale), int(h * self.view_scale)),
                                                interpolation=cv2.INTER_LINEAR)
                        self.cache_image = ((h, w, c) +
                                            (self.view_scale, filename),
                                            self.image)

                    (h, w, c) = self.image.shape

                    # Cached help-text overlay, rendered into the top band
                    # (at most 100 px) of the scaled view.
                    sh = (0, 0, w, min(100, h))
                    if self.cache_text_lines_img[0] == sh:
                        self.text_lines_img = self.cache_text_lines_img[1]
                    else:
                        self.text_lines_img = (imagelib.get_draw_text_lines(
                            self.image, sh, [
                                '[Mouse click] - lock/unlock selection',
                                '[Mouse wheel] - change rect',
                                '[Enter] / [Space] - confirm / skip frame',
                                '[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
                                '[a] - accuracy on/off (more fps)',
                                '[f] - select face by last rect',
                                '[h] - hide this help'
                            ], (1, 1, 1)) * 255).astype(np.uint8)

                        self.cache_text_lines_img = (sh, self.text_lines_img)

                    # Interactive event loop for the current frame.
                    while True:
                        io.process_messages(0.0001)

                        new_x = self.x
                        new_y = self.y
                        new_rect_size = self.rect_size

                        right_btn_down = False
                        mouse_events = io.get_mouse_events(self.wnd_name)
                        for ev in mouse_events:
                            (x, y, ev, flags) = ev
                            if ev == io.EVENT_MOUSEWHEEL and not self.rect_locked:
                                # Wheel resizes the selection rect; bigger
                                # rects change in proportionally bigger steps.
                                mod = 1 if flags > 0 else -1
                                diff = 1 if new_rect_size <= 40 else np.clip(
                                    new_rect_size / 10, 1, 10)
                                new_rect_size = max(5,
                                                    new_rect_size + diff * mod)
                            elif ev == io.EVENT_LBUTTONDOWN:
                                self.rect_locked = not self.rect_locked
                                self.extract_needed = True
                            elif ev == io.EVENT_RBUTTONDOWN:
                                right_btn_down = True
                            elif not self.rect_locked:
                                # Follow the mouse while unlocked (convert
                                # screen to image coords via view_scale).
                                new_x = np.clip(x, 0, w - 1) / self.view_scale
                                new_y = np.clip(y, 0, h - 1) / self.view_scale

                        # Only the most recent key event is considered.
                        key_events = io.get_key_events(self.wnd_name)
                        key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[
                            -1] if len(key_events) > 0 else (0, 0, False,
                                                             False, False)

                        if (key == ord('f')
                                or right_btn_down) and self.rect_locked:
                            # confirm frame
                            is_frame_done = True
                            self.last_outer = self.temp_outer
                            data_rects.append(self.rect)
                            data_landmarks.append(self.landmarks)
                            self.auto = True
                            break
                        elif (key == ord('f') or key == ord('s')
                              or self.auto) and len(self.last_outer) != 0:
                            # Auto-tracking: place the rect from the last
                            # confirmed face outline, then accept the frame
                            # automatically when the fresh outline is close
                            # in position, area and orientation.
                            last_mid = F.mid_point(self.last_outer)
                            last_border = np.linalg.norm(
                                np.array(self.last_outer[0]) -
                                np.array(self.last_outer[1]))
                            last_area = F.poly_area(self.last_outer)
                            x, y = last_mid
                            new_x = np.clip(x, 0, w - 1) / self.view_scale
                            new_y = np.clip(y, 0, h - 1) / self.view_scale
                            new_rect_size = last_border / 2 / self.view_scale * 0.8
                            # make sure rect and landmarks have been refreshed
                            # if self.x == new_x and self.y == new_y and len(self.temp_outer) != 0:
                            if len(self.temp_outer) != 0:
                                # compare dist and area
                                temp_mid = F.mid_point(self.temp_outer)
                                dist = np.linalg.norm(
                                    np.array(temp_mid) - np.array(last_mid))
                                dist_r = dist / last_border
                                temp_area = F.poly_area(self.temp_outer)
                                area_r = temp_area / last_area
                                v0 = np.array(last_mid) - np.array(
                                    self.last_outer[0])
                                v1 = np.array(temp_mid) - np.array(
                                    self.temp_outer[0])
                                angle = math.fabs(F.angle_between(v0, v1))
                                # NOTE(review): thresholds look empirical —
                                # TODO confirm tuning against the tracker.
                                if dist_r < 0.5 and 0.5 < area_r < 1.5 and angle < 0.7:
                                    is_frame_done = True
                                    self.last_outer = self.temp_outer
                                    data_rects.append(self.rect)
                                    data_landmarks.append(self.landmarks)
                                    self.auto = True
                                    break
                                elif key == ord('s'):
                                    is_frame_done = True
                                    break
                                elif self.x != new_x or self.y != new_y:
                                    # could retry after waiting one more update round
                                    pass
                                else:
                                    # Tracking lost: leave auto mode and beep
                                    # three times to alert the user.
                                    self.auto = False
                                    for i in range(3):
                                        time.sleep(0.1)
                                        print('\a')
                        elif key == ord('\r') or key == ord('\n'):
                            # confirm frame
                            is_frame_done = True
                            data_rects.append(self.rect)
                            data_landmarks.append(self.landmarks)
                            break
                        elif key == ord(' '):
                            #confirm skip frame
                            is_frame_done = True
                            break
                        elif key == ord('z') and len(self.result) > 0:
                            #go prev frame

                            if self.rect_locked:
                                self.rect_locked = False
                                # Only save the face if the rect is still locked
                                data_rects.append(self.rect)
                                data_landmarks.append(self.landmarks)

                            # Pull the last finished frame back to the front
                            # of the queue and re-mark it.
                            self.input_data.insert(0, self.result.pop())
                            io.progress_bar_inc(-1)
                            need_remark_face = True

                            break
                        elif key == ord('.'):
                            #go next frame

                            if self.rect_locked:
                                self.rect_locked = False
                                # Only save the face if the rect is still locked
                                data_rects.append(self.rect)
                                data_landmarks.append(self.landmarks)

                            need_remark_face = True
                            is_frame_done = True
                            break
                        elif key == ord('q'):
                            #skip remaining

                            if self.rect_locked:
                                self.rect_locked = False
                                data_rects.append(self.rect)
                                data_landmarks.append(self.landmarks)

                            # Flush the whole remaining queue into results.
                            while len(self.input_data) > 0:
                                self.result.append(self.input_data.pop(0))
                                io.progress_bar_inc(1)

                            break

                        elif key == ord('h'):
                            # Toggle the help overlay.
                            self.hide_help = not self.hide_help
                            break
                        elif key == ord('a'):
                            # Toggle accurate (slower) landmark extraction.
                            self.landmarks_accurate = not self.landmarks_accurate
                            break

                        # Any change in position/size (or a pending redraw)
                        # needs a fresh extraction pass: return a Data item
                        # for this frame to the host.
                        if self.x != new_x or \
                           self.y != new_y or \
                           self.rect_size != new_rect_size or \
                           self.extract_needed or \
                           redraw_needed:
                            self.x = new_x
                            self.y = new_y
                            self.rect_size = new_rect_size
                            self.rect = (int(self.x - self.rect_size),
                                         int(self.y - self.rect_size),
                                         int(self.x + self.rect_size),
                                         int(self.y + self.rect_size))

                            if redraw_needed:
                                redraw_needed = False
                                return ExtractSubprocessor.Data(
                                    filename,
                                    landmarks_accurate=self.landmarks_accurate)
                            else:
                                return ExtractSubprocessor.Data(
                                    filename,
                                    rects=[self.rect],
                                    landmarks_accurate=self.landmarks_accurate)

                else:
                    is_frame_done = True

                if is_frame_done:
                    # Move the finished frame to results and reset UI state.
                    self.result.append(data)
                    self.input_data.pop(0)
                    io.progress_bar_inc(1)
                    self.extract_needed = True
                    self.rect_locked = False
                    self.temp_outer = []

        return None