Code Example #1
def detect_img_folder(img_folder, out_folder, yolo):
    mkdir_if_not_exist(out_folder)
    path_list, name_list = traverse_dir_files(img_folder)
    print_info('Number of images: %s' % len(path_list))

    _, imgs_names = traverse_dir_files(out_folder)

    count = 0
    for path, name in zip(path_list, name_list):
        if path.endswith('.gif'):
            continue

        out_name = name + '.d.jpg'
        if out_name in imgs_names:
            print_info('Already processed: %s' % name)
            continue

        print_info('Detecting image: %s' % name)

        try:
            image = Image.open(path)
            out_file = os.path.join(ROOT_DIR, 'face', 'yolov3', 'output_data',
                                    'logAll_res.txt')
            r_image = yolo.detect_image(image, ('logAll/' + name), out_file)
            r_image.save(os.path.join(out_folder, out_name))
        except Exception as e:
            print(e)

        count += 1
        if count % 100 == 0:
            print_info('Processed: %s' % count)
    yolo.close_session()
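
These examples rely on a few project utilities (mkdir_if_not_exist, traverse_dir_files, print_info) defined elsewhere in the repository. Below is a minimal sketch of the two directory helpers, inferred only from how the examples call them; the real implementations may differ, so the semantics here are assumptions.

import os
import shutil


def mkdir_if_not_exist(folder, is_delete=False):
    # With is_delete=True the callers expect a fresh, empty folder (assumed semantics)
    if is_delete and os.path.exists(folder):
        shutil.rmtree(folder)
    if not os.path.exists(folder):
        os.makedirs(folder)


def traverse_dir_files(folder):
    # Return parallel lists of file paths and file names found under folder
    path_list, name_list = [], []
    for root, _, files in os.walk(folder):
        for name in files:
            path_list.append(os.path.join(root, name))
            name_list.append(name)
    return path_list, name_list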
Code Example #2
File: csv_reader.py  Project: SpikeKing/XX-ImageLabel
def process_csv(file_name):
    data_lines, tag_dict = read_csv(file_name)

    all_file = os.path.join(TXT_DATA, 'all_raws')
    create_file(all_file)

    for data_line in data_lines:
        cid, tags, content = data_line
        write_line(all_file, cid + u'---' + tags + u'---' + content)
        # seg_list = cut_sentence(content)
        # print(seg_list)
        # if seg_list:
        #     write_line(all_file, ' '.join(seg_list))

    tags_folder = os.path.join(TXT_DATA, 'raws')
    mkdir_if_not_exist(tags_folder, is_delete=True)
    for tag in tag_dict.keys():
        tag_file = os.path.join(tags_folder, tag)
        feed_dict = dict()
        for data_feed in tag_dict[tag]:
            (feed_id, content) = data_feed
            if feed_id in feed_dict:
                print('Duplicate ID {}'.format(feed_id))
            feed_dict[feed_id] = content
        for feed_id in feed_dict.keys():
            content = feed_dict[feed_id]
            write_line(tag_file, feed_id + u',' + content)
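
Example #2 also leans on small text-file helpers. A plausible sketch, consistent with the calls above (create_file truncates; write_line appends, which is why the per-tag files get created on first write even without create_file). These are assumptions, not the project's actual code:

import codecs


def create_file(file_name):
    # Create the file empty, truncating any previous content (assumed semantics)
    codecs.open(file_name, 'w', encoding='utf-8').close()


def write_line(file_name, line):
    # Append one line; append mode creates the per-tag files on first write
    with codecs.open(file_name, 'a', encoding='utf-8') as f:
        f.write(line + u'\n')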
Code Example #3
def data_processor_testV3():
    dataset_dir = os.path.join(DATASET_DIR, 's2a4zsV4')

    # person_path = "/Users/wangchenlong/Downloads/seeprettyface_asian_stars"
    person_path = "/Users/wangchenlong/Downloads/SCUT-FBP5500_v2/Images"
    paths_list, names_list = traverse_dir_files(person_path)

    trainA_dir = os.path.join(dataset_dir, 'trainA')
    testA_dir = os.path.join(dataset_dir, 'testA')
    mkdir_if_not_exist(trainA_dir)
    mkdir_if_not_exist(testA_dir)

    train_size = 5000
    test_size = 100
    print_size = 100

    count = 0
    random.shuffle(paths_list)
    for path in paths_list:
        img = cv2.imread(path)
        if img is None:  # skip unreadable or non-image files
            continue
        img = cv2.resize(img, (256, 256))
        out_dir = trainA_dir if count < train_size else testA_dir
        file_name = os.path.join(out_dir, u"c_{:04d}.jpg".format(count + 1))
        cv2.imwrite(file_name, img)
        count += 1
        if count % print_size == 0:
            print(u'[Info] run count: {}'.format(count))
        if count == train_size + test_size:
            break
    print('[Info] data processing finished')
Code Example #4
def process_csv(file_name):
    """
    处理CSV文件
    :param file_name: csv文件名
    :return: None
    """
    csv_rows = get_csv_reader(file_name)
    out_folder = SAMPLES_DIR
    mkdir_if_not_exist(out_folder, is_delete=True)

    included_cols = [0, 9, 13]  # ["ID", "tags", "description"]
    tags_all = traverse_tags()

    out_file = os.path.join(DATA_DIR, 'hot_content-2018-08-08-17283268.txt')
    create_file(out_file)

    count = 0
    for row in csv_rows:
        count += 1
        if count == 1 or not row or len(row) < 14:  # skip the header and short rows (row[13] needs 14 columns)
            continue
        c_row = [remove_slash(row[i]) for i in included_cols]
        [c_id, r_tag, c_content] = c_row
        c_tags = filter_content_tags(r_tag.split(','))  # keep only the final tags
        for c_tag in c_tags:
            if c_tag in tags_all:
                write_line(
                    out_file,
                    c_id + u'---' + ','.join(c_tags) + u'---' + c_content)
                break

    print('CSV processed!')
Code Example #5
    def __init__(self, img_folder, out_folder):
        self.img_folder = img_folder
        self.out_folder = out_folder

        mkdir_if_not_exist(out_folder)  # create the output folder

        self.params_path = os.path.join(MODEL_DATA,
                                        'yolov3.weights')  # YOLO v3 weights file
        self.classes_path = os.path.join(CONFIGS, 'coco.names')  # class names file
        self.targets_path = os.path.join(CONFIGS, 'traffic.names')

        self.classes_name = load_classes(self.classes_path)  # load class names
        self.num_classes = len(self.classes_name)  # number of classes

        self.targets_name = load_classes(self.targets_path)

        self.anchors = np.array([(10, 13), (16, 30), (33, 23), (30, 61),
                                 (62, 45), (59, 119), (116, 90), (156, 198),
                                 (373, 326)])  # anchors

        self.confidence = 0.50  # confidence threshold
        self.nms_thresh = 0.20  # NMS threshold
        self.input_dim = 416  # YOLOv3 input size

        gpu = '1'  # GPU id(s)
        gpu = [int(x) for x in gpu.replace(" ", "").split(",")]
        self.ctx = try_gpu(gpu)[0]  # select the context

        self.net = self.load_model()  # load the network
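
load_classes and try_gpu are also project helpers. Reading a Darknet-style .names file is simple enough to sketch (one class name per line, as in the standard coco.names; this implementation is an assumption):

def load_classes(names_path):
    # One class name per line, blank lines ignored
    with open(names_path) as f:
        return [line.strip() for line in f if line.strip()]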
Code Example #6
    def __init__(self, img_folder, out_folder):
        self.img_folder = img_folder
        self.out_folder = out_folder

        mkdir_if_not_exist(out_folder)

        self.model_path = os.path.join(MODEL_DATA, 'yolo_weights.h5')
        self.classes_path = os.path.join(CONFIGS, 'coco_classes.txt')
        self.anchors_path = os.path.join(CONFIGS, 'yolo_anchors.txt')
Code Example #7
def init_city_keywords():
    kw_path = os.path.join(TXT_DATA, 'res_kw', 'cities')
    if os.path.exists(kw_path):  # check before creating, otherwise this always returns early
        print('Folder already exists!')
        return
    mkdir_if_not_exist(kw_path)
    all_city = get_all_cities()

    for city in all_city:
        city_path = os.path.join(kw_path, city)
        write_line(city_path, city)
Code Example #8
def write_tag_keywords():
    """
    写入文本的标签
    :return: None
    """
    kw_folder = KEYWORDS_DIR
    mkdir_if_not_exist(kw_folder)
    all_tags = traverse_tags()
    for tag in all_tags:
        file_name = os.path.join(kw_folder, tag)
        write_line(file_name, tag)  # write every tag
Code Example #9
        def _visualize_output():
            last_frame_index = 0
            last_frame_time = time.time()
            fps_history = []
            all_gaze_histories = []

            if args.fullscreen:
                cv.namedWindow('vis', cv.WND_PROP_FULLSCREEN)
                cv.setWindowProperty('vis', cv.WND_PROP_FULLSCREEN, cv.WINDOW_FULLSCREEN)

            while True:
                # If no output to visualize, show unannotated frame
                if inferred_stuff_queue.empty():
                    next_frame_index = last_frame_index + 1
                    if next_frame_index in data_source._frames:
                        next_frame = data_source._frames[next_frame_index]
                        if 'faces' in next_frame and len(next_frame['faces']) == 0:
                            if not args.headless:
                                cv.imshow('vis', next_frame['bgr'])
                            if args.record_video:
                                video_out_queue.put_nowait(next_frame_index)
                            last_frame_index = next_frame_index
                    if cv.waitKey(1) & 0xFF == ord('q'):
                        return
                    continue

                # Get output from neural network and visualize
                output = inferred_stuff_queue.get()
                bgr = None
                for j in range(batch_size):
                    frame_index = output['frame_index'][j]
                    if frame_index not in data_source._frames:
                        continue
                    frame = data_source._frames[frame_index]

                    # Decide which landmarks are usable
                    heatmaps_amax = np.amax(output['heatmaps'][j, :].reshape(-1, 18), axis=0)
                    can_use_eye = np.all(heatmaps_amax > 0.7)
                    can_use_eyelid = np.all(heatmaps_amax[0:8] > 0.75)
                    can_use_iris = np.all(heatmaps_amax[8:16] > 0.8)

                    start_time = time.time()
                    eye_index = output['eye_index'][j]
                    bgr = frame['bgr']
                    eye = frame['eyes'][eye_index]
                    eye_image = eye['image']
                    eye_side = eye['side']
                    eye_landmarks = output['landmarks'][j, :]
                    eye_radius = output['radius'][j][0]
                    if eye_side == 'left':
                        eye_landmarks[:, 0] = eye_image.shape[1] - eye_landmarks[:, 0]
                        eye_image = np.fliplr(eye_image)

                    # Embed eye image and annotate for picture-in-picture
                    eye_upscale = 2
                    eye_image_raw = cv.cvtColor(cv.equalizeHist(eye_image), cv.COLOR_GRAY2BGR)
                    eye_image_raw = cv.resize(eye_image_raw, (0, 0), fx=eye_upscale, fy=eye_upscale)
                    eye_image_annotated = np.copy(eye_image_raw)
                    if can_use_eyelid:
                        cv.polylines(
                            eye_image_annotated,
                            [np.round(eye_upscale * eye_landmarks[0:8]).astype(np.int32)
                                 .reshape(-1, 1, 2)],
                            isClosed=True, color=(255, 255, 0), thickness=1, lineType=cv.LINE_AA,
                        )
                    if can_use_iris:
                        cv.polylines(
                            eye_image_annotated,
                            [np.round(eye_upscale * eye_landmarks[8:16]).astype(np.int32)
                                 .reshape(-1, 1, 2)],
                            isClosed=True, color=(0, 255, 255), thickness=1, lineType=cv.LINE_AA,
                        )
                        cv.drawMarker(
                            eye_image_annotated,
                            tuple(np.round(eye_upscale * eye_landmarks[16, :]).astype(np.int32)),
                            color=(0, 255, 255), markerType=cv.MARKER_CROSS, markerSize=4,
                            thickness=1, line_type=cv.LINE_AA,
                        )
                    face_index = int(eye_index / 2)
                    eh, ew, _ = eye_image_raw.shape
                    v0 = face_index * 2 * eh
                    v1 = v0 + eh
                    v2 = v1 + eh

                    print('[Info] eye_side: {}'.format(eye_side))
                    u0 = 0 if eye_side == 'left' else ew
                    u1 = u0 + ew

                    bgr_h, bgr_w, _ = bgr.shape
                    if u1 > bgr_w:
                        u1 = bgr_w
                        eye_image_raw = eye_image_raw[:, 0:u1 - u0]
                        eye_image_annotated = eye_image_annotated[:, 0:u1 - u0]

                    print('[Info] bgr: {}'.format(bgr.shape))
                    print('[Info] bgr[v0:v1, u0:u1]: {}, u0: {}, u1: {}'.format(bgr[v0:v1, u0:u1].shape, u0, u1))
                    print('[Info] eye_image_raw: {}'.format(eye_image_raw.shape))

                    bgr[v0:v1, u0:u1] = eye_image_raw
                    bgr[v1:v2, u0:u1] = eye_image_annotated

                    # Visualize preprocessing results
                    frame_landmarks = (frame['smoothed_landmarks']
                                       if 'smoothed_landmarks' in frame
                                       else frame['landmarks'])
                    for f, face in enumerate(frame['faces']):
                        for landmark in frame_landmarks[f][:-1]:
                            cv.drawMarker(bgr, tuple(np.round(landmark).astype(np.int32)),
                                          color=(0, 0, 255), markerType=cv.MARKER_STAR,
                                          markerSize=2, thickness=1, line_type=cv.LINE_AA)
                        cv.rectangle(
                            bgr, tuple(np.round(face[:2]).astype(np.int32)),
                            tuple(np.round(np.add(face[:2], face[2:])).astype(np.int32)),
                            color=(0, 255, 255), thickness=1, lineType=cv.LINE_AA,
                        )

                    # Transform predictions
                    eye_landmarks = np.concatenate([eye_landmarks,
                                                    [[eye_landmarks[-1, 0] + eye_radius,
                                                      eye_landmarks[-1, 1]]]])
                    eye_landmarks = np.asmatrix(np.pad(eye_landmarks, ((0, 0), (0, 1)),
                                                       'constant', constant_values=1.0))
                    eye_landmarks = (eye_landmarks *
                                     eye['inv_landmarks_transform_mat'].T)[:, :2]
                    eye_landmarks = np.asarray(eye_landmarks)
                    eyelid_landmarks = eye_landmarks[0:8, :]
                    iris_landmarks = eye_landmarks[8:16, :]
                    iris_centre = eye_landmarks[16, :]
                    eyeball_centre = eye_landmarks[17, :]
                    eyeball_radius = np.linalg.norm(eye_landmarks[18, :] -
                                                    eye_landmarks[17, :])
                    print('[Info] eyeball_radius: {}'.format(eyeball_radius))

                    # Smooth and visualize gaze direction
                    num_total_eyes_in_frame = len(frame['eyes'])
                    if len(all_gaze_histories) != num_total_eyes_in_frame:
                        all_gaze_histories = [list() for _ in range(num_total_eyes_in_frame)]
                    gaze_history = all_gaze_histories[eye_index]
                    if can_use_eye:
                        # Visualize landmarks
                        cv.drawMarker(  # Eyeball centre
                            bgr, tuple(np.round(eyeball_centre).astype(np.int32)),
                            color=(0, 255, 0), markerType=cv.MARKER_CROSS, markerSize=4,
                            thickness=1, line_type=cv.LINE_AA,
                        )
                        # cv.circle(  # Eyeball outline
                        #     bgr, tuple(np.round(eyeball_centre).astype(np.int32)),
                        #     int(np.round(eyeball_radius)), color=(0, 255, 0),
                        #     thickness=1, lineType=cv.LINE_AA,
                        # )

                        # Draw "gaze"
                        # from models.elg import estimate_gaze_from_landmarks
                        # current_gaze = estimate_gaze_from_landmarks(
                        #     iris_landmarks, iris_centre, eyeball_centre, eyeball_radius)
                        i_x0, i_y0 = iris_centre
                        e_x0, e_y0 = eyeball_centre
                        theta = -np.arcsin(np.clip((i_y0 - e_y0) / eyeball_radius, -1.0, 1.0))
                        phi = np.arcsin(np.clip((i_x0 - e_x0) / (eyeball_radius * -np.cos(theta)),
                                                -1.0, 1.0))
                        current_gaze = np.array([theta, phi])
                        gaze_history.append(current_gaze)
                        gaze_history_max_len = 10
                        if len(gaze_history) > gaze_history_max_len:
                            del gaze_history[:-gaze_history_max_len]  # trim in place so all_gaze_histories stays in sync
                        util.gaze.draw_gaze(bgr, iris_centre, np.mean(gaze_history, axis=0),
                                            length=120.0, thickness=1)
                    else:
                        gaze_history.clear()

                    if can_use_eyelid:
                        cv.polylines(
                            bgr, [np.round(eyelid_landmarks).astype(np.int32).reshape(-1, 1, 2)],
                            isClosed=True, color=(255, 255, 0), thickness=1, lineType=cv.LINE_AA,
                        )

                    if can_use_iris:
                        cv.polylines(
                            bgr, [np.round(iris_landmarks).astype(np.int32).reshape(-1, 1, 2)],
                            isClosed=True, color=(0, 255, 255), thickness=1, lineType=cv.LINE_AA,
                        )
                        cv.drawMarker(
                            bgr, tuple(np.round(iris_centre).astype(np.int32)),
                            color=(0, 255, 255), markerType=cv.MARKER_CROSS, markerSize=4,
                            thickness=1, line_type=cv.LINE_AA,
                        )

                    print('[Info] drawing finished!')
                    frames_dir = os.path.join(DATA_DIR, "frames")
                    mkdir_if_not_exist(frames_dir)
                    frame_path = os.path.join(frames_dir, '{}.out.jpg'.format(frame_index))
                    print('[Info] writing video frame: {}'.format(frame_path))
                    cv.imwrite(frame_path, bgr)
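
The pitch/yaw arithmetic in the middle of Example #9 follows from a spherical eyeball model: the iris centre is treated as the projection of a point on a sphere of radius eyeball_radius around the eyeball centre, with the image y axis pointing down. Pulled out as a standalone function for clarity (same math as above; the function itself is not part of the original module):

import numpy as np


def gaze_angles(iris_centre, eyeball_centre, eyeball_radius):
    # Under the spherical model (image y pointing down):
    #   dy = -r * sin(theta)            =>  theta = -arcsin(dy / r)
    #   dx = -r * cos(theta) * sin(phi) =>  phi   =  arcsin(dx / (-r * cos(theta)))
    # np.clip guards against landmarks that fall slightly outside the sphere.
    i_x0, i_y0 = iris_centre
    e_x0, e_y0 = eyeball_centre
    theta = -np.arcsin(np.clip((i_y0 - e_y0) / eyeball_radius, -1.0, 1.0))
    phi = np.arcsin(np.clip((i_x0 - e_x0) / (eyeball_radius * -np.cos(theta)),
                            -1.0, 1.0))
    return np.array([theta, phi])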
Code Example #10
def main():
    frames_dir = os.path.join(DATA_DIR, 'frames')
    mkdir_if_not_exist(VIDS_DIR)
    mkdir_if_not_exist(frames_dir)

    # from_video = os.path.join(VIDS_DIR, "normal_video.mp4")
    from_video = os.path.join(VIDS_DIR, "vid_no_glasses.mp4")

    # record_video = os.path.join(DATA_DIR, "normal_video.out.mp4")

    coloredlogs.install(
        datefmt='%d/%m %H:%M',
        fmt='%(asctime)s %(levelname)s %(message)s',
        level="INFO",
    )

    # Check if GPU is available
    from tensorflow.python.client import device_lib

    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True))
    gpu_available = False
    try:
        gpus = [
            d for d in device_lib.list_local_devices(config=session_config)
            if d.device_type == 'GPU'
        ]
        gpu_available = len(gpus) > 0
    except Exception as e:
        print('[Info] GPU check failed ({}), falling back to CPU!'.format(e))

    print('[Info] GPU available: {}'.format(gpu_available))
    # -----------------------------------------------------------------------#

    tf.logging.set_verbosity(tf.logging.INFO)
    session = tf.Session(config=session_config)

    batch_size = 2  # batch size
    print('[Info] input video path: {}'.format(from_video))
    assert os.path.isfile(from_video)

    # the model comes in a large and a small variant
    data_source = Video(
        from_video,
        tensorflow_session=session,
        batch_size=batch_size,
        data_format='NCHW' if gpu_available else 'NHWC',
        # eye_image_shape=(108, 180)
        eye_image_shape=(36, 60))

    # Define model
    model = ELG(
        session,
        train_data={'videostream': data_source},
        # first_layer_stride=3,
        first_layer_stride=1,
        # num_modules=3,
        num_modules=2,
        # num_feature_maps=64,
        num_feature_maps=32,
        learning_schedule=[
            {
                'loss_terms_to_optimize': {
                    'dummy': ['hourglass', 'radius']
                },
            },
        ],
    )

    infer = model.inference_generator()

    count = 0

    while True:
        print('')
        print('-' * 50)
        output = next(infer)
        process_output(output, batch_size, data_source, frames_dir)  # process the outputs
        count += 1
        print('count: {}'.format(count))
        if count == 10:
            break
Code Example #11
    def export_model(self):
        """
        Tensor names:
        output: self.test_fake_B: Tensor("generator_B/Tanh:0", shape=(1, 256, 256, 3), dtype=float32)
        input: self.test_domain_A: Tensor("test_domain_A:0", shape=(1, 256, 256, 3), dtype=float32)

        [Info] input_tensor: Tensor("input_1:0", shape=(?, 224, 224, 3), dtype=float32)
        [Info] output_tensors.values(): [<tf.Tensor 'dense_1/Softmax:0' shape=(?, 10) dtype=float32>]
        :return:
        """
        from smnn.io.parser import sm_kv_record_parser

        data_schema = os.path.join(DATA_DIR, 'schema.json')
        sm_parser = sm_kv_record_parser.SmKVRecordParser(
            data_schema, '[dat]', '[common]')
        sm_parser.init()

        raw_input_tensor = tf.placeholder(tf.string, [None])
        tensor_dict = sm_parser.get_tensor_dict(raw_input_tensor)
        image_b64 = tensor_dict['image']  # comes from data_schema

        image = tf.decode_base64(image_b64)
        image = tf.decode_raw(image, tf.float32)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.reshape(image, [-1, 256, 256, 3])  # image tensor
        img_tensor = image / 127.5 - 1.

        input_tensor = tf.get_default_graph().get_tensor_by_name(
            "{}:0".format("test_domain_A"))
        output_tensor = tf.get_default_graph().get_tensor_by_name(
            "{}:0".format("generator_B/Tanh"))

        res_ops = tf.contrib.graph_editor.graph_replace(
            [output_tensor], {input_tensor: img_tensor})

        inputs = {"input": raw_input_tensor}  # 输入String图像
        outputs = {"output": res_ops[0]}  # 输出

        prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def(
            inputs, outputs)
        signature_map = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            prediction_signature
        }

        legacy_op = control_flow_ops.group(
            tf.local_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()),
            tf.tables_initializer())

        res_dir = os.path.join(DATA_DIR, 'model-tf')
        mkdir_if_not_exist(res_dir)
        builder = saved_model_builder.SavedModelBuilder(res_dir)

        builder.add_meta_graph_and_variables(self.sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=signature_map,
                                             legacy_init_op=legacy_op)

        builder.save()
        print('[Info] model export finished!')
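
To sanity-check the export, the SavedModel can be loaded back with the standard TF1 loader and driven through the signature built above. A minimal sketch, assuming the export directory and serialized record format match export_model (record_bytes is an illustrative placeholder):

import tensorflow as tf

export_dir = 'model-tf'  # the res_dir used by export_model (illustrative path)

with tf.Session(graph=tf.Graph()) as sess:
    # load() returns the MetaGraphDef; its signature map yields the tensor names
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    sig = meta_graph.signature_def[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    input_name = sig.inputs['input'].name
    output_name = sig.outputs['output'].name
    record_bytes = b'...'  # a serialized KV record matching schema.json (illustrative)
    result = sess.run(output_name, feed_dict={input_name: [record_bytes]})
    print(result.shape)  # expected: (1, 256, 256, 3)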
Code Example #12
    def __init__(self):
        self.gp = GazePredicter()  # gaze predictor
        self.frames_dir = os.path.join(VIDS_DIR, 'frames')
        self.out_path = os.path.join(
            VIDS_DIR, 'out.{}.mp4'.format(get_current_time_str()))
        mkdir_if_not_exist(self.frames_dir)
Code Example #13
def main():
    ym = Y3Model()
    img_folder = os.path.join(IMG_DATA, 'jiaotong-0727')
    right_folder = os.path.join(IMG_DATA, 'jiaotong-0727-right')
    wrong_folder = os.path.join(IMG_DATA, 'jiaotong-0727-wrong')
    none_folder = os.path.join(IMG_DATA, 'jiaotong-0727-none')
    mkdir_if_not_exist(right_folder, is_delete=True)
    mkdir_if_not_exist(wrong_folder, is_delete=True)
    mkdir_if_not_exist(none_folder, is_delete=True)
    img_dict = format_img_and_anno(img_folder)

    r_count = 0
    all_count = 0
    no_recall_count = 0

    for count, img_name in enumerate(img_dict):
        (img_p, anno_p) = img_dict[img_name]
        # print(img_p)
        try:
            tag_res, img_box = ym.detect_img(img_p, True)
        except Exception as e:
            continue

        w_tags = []
        for tag in tag_res.keys():
            if tag_res[tag] <= 0.01:
                print_info('Dropping tag {} {}'.format(tag, tag_res[tag]))
                w_tags.append(tag)
        for tag in w_tags:
            tag_res.pop(tag, None)  # drop classes below 1% confidence

        all_count += 1
        p_classes = set(tag_res.keys())

        _, t_classes = read_anno_xml(anno_p)
        merge_dict = {'truck': 'car', 'bus': 'car', 'car': 'car'}  # merge classes
        t_classes = map_classes(merge_dict, t_classes)  # map truck/bus onto car
        traffic_names = [
            'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train',
            'truck', 'boat'
        ]
        t_classes = set(t_classes) & set(traffic_names)

        img_name = img_p.split('/')[-1]
        is_right = False

        if p_classes and p_classes.issubset(t_classes):  # correct detection
            r_count += 1
            img_box.save(os.path.join(right_folder, img_name + '.d.jpg'))
            is_right = True
        elif not p_classes and not t_classes:  # both empty, counted as correct
            r_count += 1
            if not img_box:
                img_box = Image.open(img_p)
            img_box.save(os.path.join(right_folder, img_name + '.d.jpg'))
            is_right = True
        elif not p_classes and t_classes:  # detected nothing, but ground truth has classes
            if not img_box:
                img_box = Image.open(img_p)
            img_box.save(os.path.join(none_folder, img_name + '.d.jpg'))
            no_recall_count += 1  # not recalled
            r_count += 1
            is_right = True
        else:  # anything else counts as a wrong detection
            if not img_box:
                img_box = Image.open(img_p)
            img_box.save(os.path.join(wrong_folder, img_name + '.d.jpg'))

        print_info('P: {}, T: {}, {}'.format(list(p_classes), list(t_classes),
                                             'right' if is_right else 'wrong'))

    right_ratio = safe_div(r_count, all_count)
    print_info('Right: {}, total: {}, not recalled: {}, accuracy: {}'.format(
        r_count, all_count, no_recall_count, right_ratio))
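
safe_div keeps the final accuracy computation from raising when the folder is empty. A one-line sketch of the assumed helper:

def safe_div(a, b):
    # Return 0.0 instead of raising ZeroDivisionError on an empty dataset (assumed semantics)
    return float(a) / b if b else 0.0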
Code Example #14
    def process_dataset(self):
        """
        Process the dataset.
        """
        c_paths_list, c_names_list = traverse_dir_files(self.cartoons_path)
        p_paths_list, p_names_list = traverse_dir_files(self.persons_path)

        random.seed(47)
        random.shuffle(c_paths_list)
        random.shuffle(p_paths_list)

        train_size = 1500  # training set size
        test_size = 100  # test set size
        print_size = 100

        count = 0
        train_person_dir = os.path.join(ROOT_DIR, 'dataset', 's2a4zsV1',
                                        'trainA')
        test_person_dir = os.path.join(ROOT_DIR, 'dataset', 's2a4zsV1',
                                       'testA')
        mkdir_if_not_exist(train_person_dir)
        mkdir_if_not_exist(test_person_dir)

        print('[Info] total real-person samples: {}'.format(len(p_paths_list)))
        for p_path in p_paths_list:
            try:
                p_img = cv2.imread(p_path)
                p_img = cv2.resize(p_img, (256, 256))

                if count < train_size:
                    p_file_name = os.path.join(
                        train_person_dir, u"p_{:04d}.jpg".format(count + 1))
                else:
                    p_file_name = os.path.join(
                        test_person_dir, u"p_{:04d}.jpg".format(count + 1))

                cv2.imwrite(p_file_name, p_img)
                count += 1
            except Exception as e:
                print('[Error] error {}'.format(e))
                continue

            if count % print_size == 0:
                print(u'[Info] run count: {}'.format(count))

            if count == train_size + test_size:
                break

        train_cartoon_dir = os.path.join(ROOT_DIR, 'dataset', 's2a4zsV1',
                                         'trainB')
        test_cartoon_dir = os.path.join(ROOT_DIR, 'dataset', 's2a4zsV1',
                                        'testB')
        mkdir_if_not_exist(train_cartoon_dir)
        mkdir_if_not_exist(test_cartoon_dir)

        count = 0
        print('[Info] total cartoon samples: {}'.format(len(c_paths_list)))
        for c_path in c_paths_list:
            try:
                c_img = cv2.imread(c_path)
                c_img = cv2.resize(c_img, (256, 256))

                if count < train_size:
                    c_file_name = os.path.join(
                        train_cartoon_dir, u"c_{:04d}.jpg".format(count + 1))
                    cv2.imwrite(c_file_name, c_img)
                else:
                    c_file_name = os.path.join(
                        test_cartoon_dir, u"c_{:04d}.jpg".format(count + 1))
                    cv2.imwrite(c_file_name, c_img)

                count += 1
            except Exception as e:
                print('[Error] error {}'.format(e))
                continue

            if count % print_size == 0:
                print(u'[Info] run count: {}'.format(count))

            if count == train_size + test_size:
                break

        print('[Info] data processing finished')
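
The two loops in Example #14 differ only in the source list, the file-name prefix, and the output folders, so they could share one helper. A sketch of such a refactor (the helper name and signature are invented for illustration, not part of the original project):

import os

import cv2


def resize_and_split(paths_list, prefix, train_dir, test_dir,
                     train_size=1500, test_size=100, print_size=100):
    # Resize each image to 256x256 and route it to train_dir or test_dir
    count = 0
    for path in paths_list:
        img = cv2.imread(path)
        if img is None:  # unreadable file, skip it
            continue
        img = cv2.resize(img, (256, 256))
        out_dir = train_dir if count < train_size else test_dir
        file_name = os.path.join(out_dir, u"{}_{:04d}.jpg".format(prefix, count + 1))
        cv2.imwrite(file_name, img)
        count += 1
        if count % print_size == 0:
            print(u'[Info] run count: {}'.format(count))
        if count == train_size + test_size:
            break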