Code example #1
def regenerate_miss_mask(database, mp_num, miss_file_path):
    pool = mp.Pool()
    AU_couple_dict = get_zip_ROI_AU()
    orig_file_path_lst = set()
    landmark = FaceLandMark(config.DLIB_LANDMARK_PRETRAIN)
    with open(miss_file_path, 'r') as file_obj:
        for line in file_obj:
            miss_mask_path = line.strip()
            subject_name = miss_mask_path.split("/")[-3]
            sequence_name = miss_mask_path.split("/")[-2]
            frame = miss_mask_path.split("/")[-1]
            frame = frame[:frame.index("_")]
            orig_train_file_path = config.DATA_PATH[
                "BP4D"] + "/release/BP4D-training/{0}/{1}/{2}.jpg".format(
                    subject_name, sequence_name, frame)
            orig_file_path_lst.add(orig_train_file_path)

    split_list = lambda A, n=3: [A[i:i + n] for i in range(0, len(A), n)]
    # guard against a zero chunk size when there are fewer files than workers
    sub_list = split_list(list(orig_file_path_lst),
                          max(1, len(orig_file_path_lst) // mp_num))
    for sub in sub_list:
        pool.apply_async(sub_process_miss_mask,
                         args=(database, landmark, sub, AU_couple_dict))
    pool.close()
    pool.join()
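
The split_list lambda above slices the collected file paths into contiguous chunks so that each pool worker receives one slice. A minimal, self-contained sketch of the same chunking idiom (toy data, no multiprocessing):

def split_list(A, n):
    n = max(1, n)  # guard: a chunk size of 0 would make range() raise ValueError
    return [A[i:i + n] for i in range(0, len(A), n)]

paths = ["frame_{}.jpg".format(i) for i in range(10)]
mp_num = 3
chunks = split_list(paths, len(paths) // mp_num)  # chunk size 3 -> 4 chunks
print(chunks)  # the pool simply queues more chunks than there are workers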
Code example #2
def check_box_and_cropface(orig_img_path, channel_first=False):

    cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_box(
        orig_img_path, channel_first=channel_first, mc_manager=None)

    AU_couple = get_zip_ROI_AU()
    already_couple = set()
    # cropped_face = np.transpose(cropped_face, (2, 0, 1))
    # cropped_face, params = transforms.random_flip(
    #     cropped_face, x_random=True, return_param=True)
    # cropped_face = np.transpose(cropped_face, (1, 2, 0))
    i = 0
    for AU, box_ls in AU_mask_dict.items():
        current_AU_couple = AU_couple[AU]
        if current_AU_couple in already_couple:
            continue
        already_couple.add(current_AU_couple)
        for box in box_ls:
            box = np.asarray([box])
            box = transforms.flip_bbox(box, (512, 512), x_flip=False)
            # chainercv-style boxes are (y_min, x_min, y_max, x_max); OpenCV wants integer (x, y) points
            x_min, y_min = int(box[0][1]), int(box[0][0])
            x_max, y_max = int(box[0][3]), int(box[0][2])
            print(box)
            cp_cropped = cropped_face.copy()
            cv2.rectangle(cp_cropped, (x_min, y_min), (x_max, y_max),
                          (0, 255, 0), 1)

            cv2.imwrite(
                "/home2/mac/test1/AU_{0}_{1}.png".format(
                    ",".join(current_AU_couple), i), cp_cropped)
            i += 1
    print(i)
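
Because the boxes follow the chainercv (y_min, x_min, y_max, x_max) convention while OpenCV draws with (x, y) corner points, the coordinate swap above is easy to get wrong. A self-contained sketch of just that conversion (toy canvas and box values, illustrative only):

import numpy as np
import cv2

canvas = np.zeros((512, 512, 3), dtype=np.uint8)
box = np.array([100.0, 50.0, 200.0, 300.0])  # (y_min, x_min, y_max, x_max)

y_min, x_min, y_max, x_max = box
cv2.rectangle(canvas, (int(x_min), int(y_min)), (int(x_max), int(y_max)),
              (0, 255, 0), 1)  # OpenCV expects integer (x, y) pairs
cv2.imwrite("box_demo.png", canvas)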
Code example #3
def adaptive_AU_relation(database_name):
    '''
    Must be called after adaptive_AU_database.
    Removes from config.AU_RELATION_BP4D the AU relations whose AUs lie in the same region.

    For example, if config.AU_RELATION_BP4D contains (10, 12) but AU 10 and AU 12 already share one region, that relation is deleted.

    Another case to note: if config.AU_RELATION_BP4D contains (10, 13), AU 10 and AU 12 may have been merged into a new region in the dict, while AU 13 has been merged with yet another AU into a different new region.

    '''
    new_AU_relation = list()

    AU_couple = get_zip_ROI_AU()
    already_same_region_set = set()
    for AU, couple_tuple in AU_couple.items():
        for AU_a, AU_b in itertools.combinations(couple_tuple, 2):
            already_same_region_set.add(tuple(sorted([int(AU_a), int(AU_b)])))
    if database_name == "BP4D":
        for AU_tuple in config.AU_RELATION_BP4D:
            if tuple(sorted(AU_tuple)) not in already_same_region_set:
                new_AU_relation.append(tuple(sorted(AU_tuple)))
        config.AU_RELATION_BP4D = new_AU_relation
    elif database_name == "DISFA":
        for AU_tuple in config.AU_RELATION_DISFA:
            if tuple(sorted(AU_tuple)) not in already_same_region_set:
                new_AU_relation.append(tuple(sorted(AU_tuple)))
        config.AU_RELATION_DISFA = new_AU_relation
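
A standalone sketch of the same filtering step, with toy data standing in for get_zip_ROI_AU() and the config relation list (all values here are illustrative):

import itertools

AU_couple = {"10": ("10", "12"), "12": ("10", "12"), "13": ("13",)}  # toy AU -> region couple
AU_RELATION = [(10, 12), (10, 13)]

same_region = set()
for couple in set(AU_couple.values()):
    for a, b in itertools.combinations(couple, 2):
        same_region.add(tuple(sorted((int(a), int(b)))))

filtered = [tuple(sorted(t)) for t in AU_RELATION
            if tuple(sorted(t)) not in same_region]
print(filtered)  # [(10, 13)] -- (10, 12) is dropped, same region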
Code example #4
    def __init__(self,
                 database,
                 L,
                 fold,
                 split_name,
                 split_index,
                 mc_manager,
                 train_all_data,
                 two_stream_mode,
                 pretrained_target="",
                 paper_report_label_idx=None):
        self.database = database
        self.split_name = split_name
        self.L = L  # used to fetch optical-flow images L/2 frames before and L/2 frames after
        self.au_couple_dict = get_zip_ROI_AU()
        self.mc_manager = mc_manager
        self.au_couple_child_dict = get_AU_couple_child(self.au_couple_dict)
        self.AU_intensity_label = {}  # subject + "/" + emotion_seq + "/" + frame => ... not implemented
        self.pretrained_target = pretrained_target
        self.two_stream_mode = two_stream_mode
        self.dir = config.DATA_PATH[database]  # BP4D/DISFA/ BP4D_DISFA
        self.paper_report_label_idx = paper_report_label_idx
        if train_all_data:
            id_list_file_path = os.path.join(
                self.dir + "/idx/{}_fold".format(fold), "full_pretrain.txt")
        else:
            id_list_file_path = os.path.join(
                self.dir + "/idx/{0}_fold".format(fold),
                "id_{0}_{1}.txt".format(split_name, split_index))
        self.result_data = []

        print("idfile:{}".format(id_list_file_path))
        with open(id_list_file_path, "r") as file_obj:
            for idx, line in enumerate(file_obj):
                if line.rstrip():
                    line = line.rstrip()
                    relative_path, au_set_str, _, current_database_name = line.split(
                        "\t")
                    AU_set = set(
                        AU for AU in au_set_str.split(',')
                        if AU in config.AU_ROI and AU in config.AU_SQUEEZE.inv)
                    if au_set_str == "0":
                        AU_set = set()
                    rgb_path = config.RGB_PATH[
                        current_database_name] + os.path.sep + relative_path  # the id file stores relative paths
                    flow_path = config.FLOW_PATH[
                        current_database_name] + os.path.sep + relative_path
                    if os.path.exists(rgb_path):
                        self.result_data.append((rgb_path, flow_path, AU_set,
                                                 current_database_name))

        # order by subject, sequence, then numeric frame index
        def _sort_key(entry):
            parts = entry[0].split("/")
            frame = parts[-1][:parts[-1].rindex(".")]
            return (parts[-3], parts[-2], int(frame))

        self.result_data.sort(key=_sort_key)
        self._num_examples = len(self.result_data)
        print("read id file done, all examples:{}".format(self._num_examples))
Code example #5
File: AU_dataset.py Project: zhangxujinsh/AU_R-CNN
    def __init__(self,
                 img_resolution,
                 database,
                 fold,
                 split_name,
                 split_index,
                 mc_manager,
                 prefix="",
                 pretrained_target=""):
        self.database = database
        self.img_resolution = img_resolution
        self.split_name = split_name
        self.au_couple_dict = get_zip_ROI_AU()
        self.mc_manager = mc_manager
        self.au_couple_child_dict = get_AU_couple_child(self.au_couple_dict)
        self.AU_intensity_label = {}  # subject + "/" + emotion_seq + "/" + frame => ... not implemented
        self.pretrained_target = pretrained_target
        self.dir = config.DATA_PATH[database]  # BP4D/DISFA/ BP4D_DISFA

        id_list_file_path = os.path.join(
            self.dir + "/idx/{0}_fold{1}".format(fold, prefix),
            "intensity_{0}_{1}.txt".format(split_name, split_index))
        self.result_data = []

        self.video_offset = OrderedDict()
        self.video_count = defaultdict(int)
        print("idfile:{}".format(id_list_file_path))
        with open(id_list_file_path, "r") as file_obj:
            for idx, line in enumerate(file_obj):
                if line.rstrip():
                    line = line.rstrip()
                    img_path, au_set_str, from_img_path, current_database_name = line.split(
                        "\t")
                    AU_intensity = np.fromstring(au_set_str,
                                                 dtype=np.int32,
                                                 sep=',')
                    from_img_path = img_path if from_img_path == "#" else from_img_path
                    img_path = config.RGB_PATH[
                        current_database_name] + os.path.sep + img_path  # the id file stores relative paths
                    from_img_path = config.RGB_PATH[
                        current_database_name] + os.path.sep + from_img_path
                    video_id = "/".join(
                        [img_path.split("/")[-3],
                         img_path.split("/")[-2]])
                    if video_id not in self.video_offset:
                        self.video_offset[video_id] = len(self.result_data)
                    self.video_count[video_id] += 1
                    if os.path.exists(img_path):
                        self.result_data.append(
                            (img_path, from_img_path, AU_intensity,
                             current_database_name))
        # order by subject, sequence, then numeric frame index
        def _sort_key(entry):
            parts = entry[0].split("/")
            frame = parts[-1][:parts[-1].rindex(".")]
            return (parts[-3], parts[-2], int(frame))

        self.result_data.sort(key=_sort_key)
        self._num_examples = len(self.result_data)
        print("read id file done, all examples:{}".format(self._num_examples))
Code example #6
    def __init__(self,
                 database,
                 fold,
                 split_name,
                 split_index,
                 mc_manager,
                 train_all_data=False,
                 read_type="rgb",
                 pretrained_target="",
                 img_resolution=config.IMG_SIZE[0]):
        self.database = database
        self.img_resolution = img_resolution
        self.split_name = split_name  # trainval or test
        self.read_type = read_type
        self.au_couple_dict = get_zip_ROI_AU()
        self.au_couple_child_dict = get_AU_couple_child(self.au_couple_dict)
        self.AU_intensity_label = {}  # subject + "/" + emotion_seq + "/" + frame => ... not implemented
        self.dir = config.DATA_PATH[database]  # BP4D/DISFA/ BP4D_DISFA
        self.pretrained_target = pretrained_target
        self.mc_manager = mc_manager
        if train_all_data:
            id_list_file_path = os.path.join(
                self.dir + "/idx/{}_fold".format(fold), "full_pretrain.txt")
        else:
            id_list_file_path = os.path.join(
                self.dir + "/idx/{0}_fold".format(fold),
                "id_{0}_{1}.txt".format(split_name, split_index))
        self.result_data = []

        print("idfile:{}".format(id_list_file_path))
        with open(id_list_file_path, "r") as file_obj:
            for idx, line in enumerate(file_obj):
                if line.rstrip():
                    line = line.rstrip()
                    img_path, au_set_str, from_img_path, current_database_name = line.split(
                        "\t")
                    AU_set = set(AU for AU in au_set_str.split(',')
                                 if AU in config.AU_ROI)
                    if au_set_str == "0":
                        AU_set = set()
                    from_img_path = img_path if from_img_path == "#" else from_img_path

                    img_path = config.RGB_PATH[
                        current_database_name] + os.path.sep + img_path  # the id file stores relative paths
                    from_img_path = config.RGB_PATH[
                        current_database_name] + os.path.sep + from_img_path
                    if os.path.exists(img_path):
                        self.result_data.append(
                            (img_path, from_img_path, AU_set,
                             current_database_name))
        # order by subject, sequence, then numeric frame index
        def _sort_key(entry):
            parts = entry[0].split("/")
            frame = parts[-1][:parts[-1].rindex(".")]
            return (parts[-3], parts[-2], int(frame))

        self.result_data.sort(key=_sort_key)
        self._num_examples = len(self.result_data)
        print("read id file done, all examples:{}".format(self._num_examples))
Code example #7
    def __init__(self, host):
        self.AU_couple = get_zip_ROI_AU()
        if import_lib:
            self.mc = mc.Client([host],
                                binary=True,
                                behaviors={
                                    'tcp_nodelay': True,
                                    "ketama": True
                                })
        else:
            self.mc = mc.Client([host], debug=0)
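
The import_lib flag above presumably records whether the fast pylibmc client could be imported, with python-memcached as the fallback (both Client signatures match the two branches above). A guarded-import sketch of how such a flag is commonly set; this is an assumption, not the repository's exact code:

try:
    import pylibmc as mc  # fast C client; supports binary=True and behaviors={...}
    import_lib = True
except ImportError:
    import memcache as mc  # pure-Python fallback; Client([host], debug=0)
    import_lib = False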
Code example #8
def check_BP4D_multi_label(label_file_dir, au_mask_dir):
    au_couple_dict = get_zip_ROI_AU()
    combine_AU = defaultdict(set)
    all_AU = set()
    all_co_occur_AU = defaultdict(int)
    for file_name in os.listdir(label_file_dir):
        subject_name = file_name[:file_name.index("_")]
        sequence_name = file_name[file_name.index("_") +
                                  1:file_name.rindex(".")]
        AU_column_idx = {}
        co_occur_AU = {}
        with open(label_file_dir + "/" + file_name, "r") as au_file_obj:
            for idx, line in enumerate(au_file_obj):
                if idx == 0:  # header specify Action Unit
                    for col_idx, AU in enumerate(line.split(",")[1:]):
                        AU_column_idx[AU] = col_idx + 1  # read header
                    continue  # read head over , continue

                lines = line.split(",")
                frame = lines[0]
                non_label_lst = [
                    int(lines[AU_column_idx[AU]])
                    for AU in config.AU_ROI.keys()
                    if int(AU) >= 40 and int(lines[AU_column_idx[AU]]) != 9
                ]
                if len(non_label_lst) > 0:
                    print(" big AU is not 9! in {}".format(label_file_dir +
                                                           "/" + file_name))
                    print("{}".format(non_label_lst))
                au_label_dict = {
                    AU: int(lines[AU_column_idx[AU]])
                    for AU in config.AU_ROI.keys()
                    if int(lines[AU_column_idx[AU]]) == 1
                }  # note: only keep label == 1 entries (not 9), i.e. only non-unknown AUs

                for AU_tuple in combinations(au_label_dict.keys(), 2):
                    co_occur_AU[tuple(sorted(map(int, AU_tuple)))] = 1

                all_AU.update(au_label_dict.keys())
                au_mask_dict = {
                    AU: "{0}/{1}/{2}/{3}_AU_{4}.png".format(
                        au_mask_dir, subject_name, sequence_name, frame,
                        ",".join(au_couple_dict[AU]))
                    for AU in au_label_dict.keys()
                }

                for AU, mask_path in au_mask_dict.items():
                    combine_AU[mask_path].add(AU)
        for AU_tuple, count in co_occur_AU.items():
            all_co_occur_AU[AU_tuple] += count
    AU_counter = defaultdict(int)
    for mask_path, AU_set in combine_AU.items():  # avoid shadowing the dict being iterated
        AU_counter[tuple(sorted(AU_set))] += 1
    # for AU, count in sorted(AU_counter.items(), key=lambda e: e[1],reverse=True):
    #     print("{0} {1}".format(AU, count))
    return AU_counter, all_AU, all_co_occur_AU
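
The co-occurrence bookkeeping records, per label file, which AU pairs appear together in at least one frame, then accumulates those flags across files. A standalone sketch of the pair-counting idiom (toy per-frame AU sets):

from collections import defaultdict
from itertools import combinations

frames = [{"1", "2", "12"}, {"2", "12"}, {"4"}]  # toy active-AU sets, one per frame

co_occur = {}  # pair -> 1 if the pair appears together in any frame of this file
for active in frames:
    for pair in combinations(active, 2):
        co_occur[tuple(sorted(map(int, pair)))] = 1

all_co_occur = defaultdict(int)
for pair, flag in co_occur.items():
    all_co_occur[pair] += flag  # across files this becomes a per-pair file count

print(dict(all_co_occur))  # counts for (1, 2), (1, 12), (2, 12), each 1 here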
Code example #9
def generate_AUCouple_ROI_mask_image(database_name, img_path):
    adaptive_AU_database(database_name)
    global MASK_COLOR

    mask_color_lst = []
    for color in MASK_COLOR:
        mask_color_lst.append(color_bgr(color))
    cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
        img_path, channel_first=False)
    AU_couple_dict = get_zip_ROI_AU()

    land = FaceLandMark(config.DLIB_LANDMARK_PRETRAIN)
    landmark, _, _ = land.landmark(image=cropped_face)
    roi_polygons = land.split_ROI(landmark)
    for roi_no, polygon_vertex_arr in roi_polygons.items():
        polygon_vertex_arr[0, :] = np.round(polygon_vertex_arr[0, :])
        polygon_vertex_arr[1, :] = np.round(polygon_vertex_arr[1, :])
        polygon_vertex_arr = sort_clockwise(polygon_vertex_arr.tolist())
        cv2.polylines(cropped_face, [polygon_vertex_arr],
                      True,
                      color_bgr(RED),
                      thickness=1)
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(cropped_face,
                    str(roi_no),
                    tuple(
                        np.mean(polygon_vertex_arr, axis=0).astype(np.int32)),
                    font,
                    0.7, (0, 255, 255),
                    thickness=1)
    already_fill_AU = set()
    idx = 0
    gen_face_lst = dict()
    AU_couple_mask = dict()
    for AU in config.AU_ROI.keys():
        AU_couple = AU_couple_dict[AU]
        if AU_couple in already_fill_AU:
            continue
        already_fill_AU.add(AU_couple)
        mask = AU_mask_dict[AU]
        AU_couple_mask[AU_couple] = mask
        color_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)

        color_mask[mask != 0] = random.choice(mask_color_lst)
        idx += 1
        new_face = cv2.addWeighted(cropped_face, 0.75, color_mask, 0.25, 0)
        gen_face_lst[AU_couple] = new_face
    return gen_face_lst, AU_couple_mask
Code example #10
File: table_script.py Project: zhangxujinsh/AU_R-CNN
def generate_AU_ROI_table():
    already_AU_couple = set()
    AU_couple_dict = get_zip_ROI_AU()
    AU_couple_region = OrderedDict()
    for AU, region_numbers in sorted(config.AU_ROI.items(), key=lambda e: e[0]):
        AU_couple = AU_couple_dict[AU]
        if AU_couple not in already_AU_couple:
            already_AU_couple.add(AU_couple)
            AU_couple_region[AU_couple] = sorted(region_numbers)

    for AU_couple, region_numbers in AU_couple_region.items():
        AU_info = []
        for AU in AU_couple:
            AU_info.append("AU {}".format(AU))
        AU_couple = " , ".join(AU_info)
        region_numbers = " , ".join(map(str, region_numbers))
        print("{0} & {1} \\\\".format(AU_couple, region_numbers))
        print("\hline")
Code example #11
File: demo_AU_rcnn.py Project: zhangxujinsh/AU_R-CNN
    def generate_AUCouple_ROI_mask_image(self, database_name, img_path,
                                         roi_activate):
        adaptive_AU_database(database_name)

        cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
            img_path, channel_first=False)
        AU_couple_dict = get_zip_ROI_AU()

        land = FaceLandMark(config.DLIB_LANDMARK_PRETRAIN)
        landmark, _, _ = land.landmark(image=cropped_face)
        roi_polygons = land.split_ROI(landmark)
        for roi_no, polygon_vertex_arr in roi_polygons.items():
            polygon_vertex_arr[0, :] = np.round(polygon_vertex_arr[0, :])
            polygon_vertex_arr[1, :] = np.round(polygon_vertex_arr[1, :])
            polygon_vertex_arr = sort_clockwise(polygon_vertex_arr.tolist())
            cv2.polylines(cropped_face, [polygon_vertex_arr],
                          True, (0, 0, 255),
                          thickness=1)
            font = cv2.FONT_HERSHEY_SIMPLEX
            cv2.putText(cropped_face,
                        str(roi_no),
                        tuple(
                            np.mean(polygon_vertex_arr,
                                    axis=0).astype(np.int32)),
                        font,
                        0.7, (0, 255, 255),
                        thickness=1)
        already_fill_AU = set()
        AUCouple_face_dict = dict()
        for AU in config.AU_ROI.keys():
            AU_couple = AU_couple_dict[AU]
            if AU_couple in already_fill_AU or AU_couple not in roi_activate:
                continue
            already_fill_AU.add(AU_couple)
            mask = AU_mask_dict[AU]
            color_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
            color_mask[mask != 0] = (199, 21, 133)
            new_face = cv2.add(cropped_face, color_mask)
            AUCouple_face_dict[AU_couple] = new_face

        return AUCouple_face_dict
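
Note the two overlay styles across these examples: example #9 blends with cv2.addWeighted (a weighted average), while this one uses cv2.add (saturating addition, which pushes masked pixels toward white). A tiny standalone comparison on toy arrays:

import numpy as np
import cv2

face = np.full((2, 2, 3), 200, dtype=np.uint8)  # toy "face" pixels
mask = np.full((2, 2, 3), 100, dtype=np.uint8)  # toy colored mask

print(cv2.add(face, mask)[0, 0])                         # [255 255 255], saturates
print(cv2.addWeighted(face, 0.75, mask, 0.25, 0)[0, 0])  # [175 175 175], blend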
Code example #12
def collect_box_coordinate(image_path, mc_manager, database):
    au_couple_dict = get_zip_ROI_AU()
    AU_group_box_coordinate = defaultdict(list)  # AU_group => box_list
    key_prefix = database + "@512|"
    try:
        cropped_face, AU_box_dict = FaceMaskCropper.get_cropface_and_box(
            image_path,
            image_path,
            channel_first=True,
            mc_manager=mc_manager,
            key_prefix=key_prefix)
    except IndexError:
        return AU_group_box_coordinate
    for AU, box_list in AU_box_dict.items():
        AU_couple = au_couple_dict[AU]
        if AU_couple in AU_group_box_coordinate:
            continue
        new_box_list = [list(box) for box in box_list]
        new_box_list.sort(key=lambda e: e[1])  # sort boxes left-to-right by x_min
        AU_group_box_coordinate[AU_couple].extend(new_box_list)
    return AU_group_box_coordinate
Code example #13
def generate_face_AUregion_mask(database, iter_crop):

    landmark = FaceLandMark(config.DLIB_LANDMARK_PRETRAIN)
    AU_couple_dict = get_zip_ROI_AU()
    already_mask = set()
    for new_face, orig_img, rect, absolute_path in iter_crop:  # absolute_path is the path of the newly cropped face image
        try:
            absolute_path = absolute_path.replace("//", "/")

            subject_name = absolute_path.split(os.sep)[-3]
            sequence = absolute_path.split(os.sep)[-2]
            filename = os.path.basename(absolute_path)
            filename = filename[:filename.rindex(".")]
            au_mask_dir = "{0}/{1}/{2}/".format(
                config.AU_REGION_MASK_PATH[database], subject_name, sequence)
            for f in glob.glob("{0}/{1}_AU_*".format(au_mask_dir, filename)):
                os.remove(f)
            for AU in config.AU_ROI.keys():
                if AU_couple_dict[AU] in already_mask:
                    continue
                mask = crop_face_img_mask(AU,
                                          orig_img,
                                          new_face,
                                          rect,
                                          landmarker=landmark)
                if not os.path.exists(au_mask_dir):
                    os.makedirs(au_mask_dir)
                    print("make dir {}".format(au_mask_dir))
                au_couple = AU_couple_dict[AU]
                already_mask.add(au_couple)
                au_couple = ",".join(au_couple)
                au_mask_path = "{0}/{1}_AU_{2}.png".format(
                    au_mask_dir, filename, au_couple)

                cv2.imwrite(au_mask_path, mask)
                print("write : {}".format(au_mask_path))
                yield mask, au_mask_path
            already_mask.clear()
        except IndexError:
            continue
Code example #14
def stats_AU_group_area(image_path, mc_manager, database):
    au_couple_dict = get_zip_ROI_AU()
    AU_group_box_area = dict()
    key_prefix = database + "|"
    try:
        cropped_face, AU_box_dict = FaceMaskCropper.get_cropface_and_box(
            image_path,
            image_path,
            channel_first=True,
            mc_manager=mc_manager,
            key_prefix=key_prefix)
    except IndexError:
        return AU_group_box_area
    for AU, box_list in AU_box_dict.items():
        AU_couple = au_couple_dict[AU]
        tot_area = 0.
        for box in box_list:
            y_min, x_min, y_max, x_max = box
            area = (x_max - x_min) * (y_max - y_min)
            tot_area += area
        tot_area /= len(box_list)
        AU_group_box_area[AU_couple] = tot_area
    return AU_group_box_area
Code example #15
def async_crop_face(database, src_img_dir, mp_num, force_write=True):
    pool = mp.Pool()
    data_path_lst = set()
    not_contain_count = 0
    contain_count = 0
    AU_couple_dict = get_zip_ROI_AU()
    for root, dirs, files in os.walk(src_img_dir):
        for file in files:
            absolute_path = root + os.sep + file
            subject_name = absolute_path.split(os.sep)[-3]
            sequence = absolute_path.split(os.sep)[-2]

            dst_dir = config.CROP_DATA_PATH[
                database] + os.sep + subject_name + os.sep + sequence + os.sep
            cropped_path = dst_dir + os.sep + os.path.basename(absolute_path)
            data_path_lst.add(absolute_path)

    split_list = lambda A, n=3: [A[i:i + n] for i in range(0, len(A), n)]
    print(len(data_path_lst) // mp_num)
    # guard against a zero chunk size when there are fewer files than workers
    sub_list = split_list(list(data_path_lst),
                          max(1, len(data_path_lst) // mp_num))
    for sub in sub_list:
        pool.apply_async(sub_process, args=(database, sub, force_write))
    pool.close()
    pool.join()
Code example #16
    data_base = options.database
    data_path = options.data_path
    if data_path is None:
        opt.error("error, data_path not given!")
    min_support = options.min_support
    if min_support is None:
        opt.error("error, min_support not given!")
    if data_base == "BP4D":
        data = read_BP4D_AU_file(data_path)
    elif data_base == "DISFA":
        data = read_BP4D_AU_file(data_path)
    import pyfpgrowth
    from collections import defaultdict
    import config
    from dataset_toolkit.compress_utils import get_zip_ROI_AU, get_AU_couple_child
    au_couple_dict = get_zip_ROI_AU()
    roi_set = set(au_couple_dict.values())
    AU_belong_roi_id = {"0": 999}

    for idx, couple in enumerate(roi_set):
        for AU in couple:
            AU_belong_roi_id[AU] = idx
    DISFA_roi_AU = defaultdict(list)
    if data_base == "DISFA":
        all_print_tuple = set()
        for AU in config.DISFA_use_AU:
            DISFA_roi_AU[AU_belong_roi_id[AU]].append(AU)
        for entry in itertools.product(*DISFA_roi_AU.values()):
            print_tuple = list(map(str, sorted(map(int, list(entry)))))
            while len(print_tuple) < config.BOX_NUM[data_base]:
                AU_1 = random.choice(["1", "2", "5"])
Code example #17
File: demo_AU_rcnn.py Project: zhangxujinsh/AU_R-CNN
def main():
    parser = argparse.ArgumentParser(
        description='generate Graph desc file script')
    parser.add_argument('--mean',
                        default=config.ROOT_PATH + "BP4D/idx/mean_rgb.npy",
                        help='image mean .npy file')
    parser.add_argument("--image",
                        default='C:/Users/machen/Downloads/tmp/face.jpg')
    parser.add_argument(
        "--model", default="C:/Users/machen/Downloads/tmp/BP4D_3_fold_1.npz")
    parser.add_argument("--pretrained_model_name",
                        '-premodel',
                        default='resnet101')
    parser.add_argument('--database', default='BP4D', help='database name')
    parser.add_argument('--device',
                        default=0,
                        type=int,
                        help='GPU device number')
    args = parser.parse_args()
    adaptive_AU_database(args.database)

    if args.pretrained_model_name == "resnet101":
        faster_rcnn = FasterRCNNResnet101(
            n_fg_class=len(config.AU_SQUEEZE),
            pretrained_model="resnet101",
            mean_file=args.mean,
            use_lstm=False,
            extract_len=1000
        )  # can be replaced with /home/machen/face_expr/result/snapshot_model.npz
    elif args.pretrained_model_name == "vgg":
        faster_rcnn = FasterRCNNVGG16(n_fg_class=len(config.AU_SQUEEZE),
                                      pretrained_model="imagenet",
                                      mean_file=args.mean,
                                      use_lstm=False,
                                      extract_len=1000)

    if os.path.exists(args.model):
        print("loading pretrained snapshot:{}".format(args.model))
        chainer.serializers.load_npz(args.model, faster_rcnn)
    if args.device >= 0:
        faster_rcnn.to_gpu(args.device)
        chainer.cuda.get_device_from_id(int(args.device)).use()

    heatmap_gen = HeatMapGenerator(np.load(args.model), use_relu=True)
    if args.device >= 0:
        heatmap_gen.to_gpu(args.device)
    cropped_face, AU_box_dict = FaceMaskCropper.get_cropface_and_box(
        args.image, args.image, channel_first=True)
    au_couple_dict = get_zip_ROI_AU()
    au_couple_child = get_AU_couple_child(
        au_couple_dict)  # AU couple tuple => child fetch list
    au_couple_box = dict()  # value is box (4 tuple coordinate) list

    for AU, AU_couple in au_couple_dict.items():
        au_couple_box[AU_couple] = AU_box_dict[AU]
    box_lst = []
    roi_no_AU_couple_dict = dict()
    roi_no = 0
    for AU_couple, couple_box_lst in au_couple_box.items():
        box_lst.extend(couple_box_lst)
        for _ in couple_box_lst:
            roi_no_AU_couple_dict[roi_no] = AU_couple
            roi_no += 1

    box_lst = np.asarray(box_lst)
    cropped_face = cropped_face.astype(np.float32)
    orig_face = cropped_face
    cropped_face = faster_rcnn.prepare(
        cropped_face)  # substract mean pixel value
    box_lst = box_lst.astype(np.float32)
    orig_box_lst = box_lst
    batch = [
        (cropped_face, box_lst),
    ]
    cropped_face, box_lst = concat_examples(
        batch, args.device)  # N,3, H, W, ;  N, F, 4

    if box_lst.shape[1] != config.BOX_NUM[args.database]:
        print("error box num {0} != {1}".format(box_lst.shape[1],
                                                config.BOX_NUM[args.database]))
        return
    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        cropped_face = chainer.Variable(cropped_face)
        box_lst = chainer.Variable(box_lst)
        roi_preds, _ = faster_rcnn.predict(cropped_face, box_lst)  # R, 22
        roi_feature_maps = faster_rcnn.extract(orig_face, orig_box_lst,
                                               'res5')  # R, 2048 7,7

        roi_images = []
        box_lst = box_lst[0].data.astype(np.int32)
        for box in box_lst:
            y_min, x_min, y_max, x_max = box
            roi_image = orig_face[:, y_min:y_max + 1,
                                  x_min:x_max + 1]  # N, 3, roi_H, roi_W
            roi_images.append(roi_image)  # list of  N, 3, roi_H, roi_W
        cmap = plt.get_cmap('jet')
        # image_activate_map = np.zeros((cropped_face.shape[2], cropped_face.shape[3]), dtype=np.float32)
        for box_id, (roi_image, roi_feature_map) in enumerate(
                zip(roi_images, roi_feature_maps)):
            y_min, x_min, y_max, x_max = box_lst[box_id]
            # 22, roi_h, roi_w, 3
            xp = chainer.cuda.get_array_module(roi_feature_map)
            roi_feature_map = xp.expand_dims(roi_feature_map, 0)
            #   class_roi_overlay_img = 22, roi_h, roi_w
            class_roi_activate_img = heatmap_gen.generate_activate_roi_map(
                roi_feature_map, (y_max - y_min + 1, x_max - x_min + 1))
            roi_pred = roi_preds[box_id]  # 22
            # choice_activate_map = np.zeros((y_max-y_min+1, x_max-x_min+1), dtype=np.float32)
            # use_choice = False
            # TODO: proper class selection is still needed, and heatmaps should be
            # combined in the cv2.add style
            if len(np.nonzero(roi_pred)[0]) > 0:
                class_idx = random.choice(np.nonzero(roi_pred)[0])
                AU = config.AU_SQUEEZE[class_idx]
                print(AU)
                choice_activate_map = class_roi_activate_img[
                    class_idx]  # roi_h, roi_w
                activation_color_map = np.round(
                    cmap(choice_activate_map)[:, :, :3] * 255).astype(np.uint8)
                overlay_img = roi_images[
                    box_id] / 2 + activation_color_map.transpose(2, 0, 1) / 2
                overlay_img = np.transpose(overlay_img,
                                           (1, 2, 0)).astype(np.uint8)
                vis_img = cv2.cvtColor(overlay_img, cv2.COLOR_RGB2BGR)
                cv2.imshow("new", vis_img)
                cv2.waitKey(0)
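
The overlay step maps a normalized activation map through matplotlib's 'jet' colormap, scales it to 0-255, and averages it with the RGB crop of the ROI. A self-contained sketch of just that compositing (random toy data):

import numpy as np
import matplotlib.pyplot as plt

cmap = plt.get_cmap('jet')
activate = np.random.rand(8, 8).astype(np.float32)  # toy activation in [0, 1]
roi_image = np.random.randint(0, 256, (3, 8, 8))    # toy CHW RGB crop

color_map = np.round(cmap(activate)[:, :, :3] * 255).astype(np.uint8)  # HWC color
overlay = roi_image / 2 + color_map.transpose(2, 0, 1) / 2  # average blend in CHW
overlay = np.transpose(overlay, (1, 2, 0)).astype(np.uint8)  # back to HWC for cv2
print(overlay.shape)  # (8, 8, 3)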
Code example #18
def build_graph_roi_single_label(faster_rcnn, reader_func, output_dir,
                                 database_name, force_generate, proc_num,
                                 cut: bool, extract_key, train_subject,
                                 test_subject):
    '''
    Currently CRF can only deal with the single-label situation,
    so /home/machen/dataset/BP4D/label_dict.txt is used to treat a combined label as a new single label.
    Example (each file contains one video!):
    node_id known_label features
    1_12 +1 np_file:/path/to/npy features:1,3,4,5,5,...
    node_id format: ${frame}_${roi}, e.g. 1_12
    or
    444 +[0,0,0,1,0,1,0] np_file:/path/to/npy features:1,3,4,5,5,...
    spatio edges can carry two factor nodes here, e.g. spatio_1 for the upper-face relation and spatio_2 for the lower-face relation:
    #edge 143 4289 spatio_1
    #edge 143 4289 spatio_2
    #edge 112 1392 temporal

    mode: RNN or CRF
    '''
    adaptive_AU_database(database_name)
    adaptive_AU_relation(database_name)
    au_couple_dict = get_zip_ROI_AU()  # value is an AU couple tuple; each tuple denotes one RoI
    # max_au_couple_len = max(len(couple) for couple in au_couple_dict.values())  # we use itertools.product instead
    label_bin_len = config.BOX_NUM[database_name]  # each box/RoI only has 1 or 0
    au_couple_set = set(au_couple_dict.values())
    au_couple_list = list(au_couple_set)
    au_couple_list.append(("1", "2", "5", "7"))  # because it is symmetric area
    is_binary_AU = True

    for video_info, subject_id in reader_func(
            output_dir,
            is_binary_AU=is_binary_AU,
            is_need_adaptive_AU_relation=False,
            force_generate=force_generate,
            proc_num=proc_num,
            cut=cut,
            train_subject=train_subject):

        extracted_feature_cache = dict()  # key = hash of the face array, value = extracted feature h (speedup)
        frame_box_cache = dict()  # key = frame, value = boxes
        frame_labels_cache = dict()
        frame_AU_couple_bbox_dict_cache = dict()
        # each video file is copying multiple version but differ in label
        if database_name == "BP4D":
            label_split_list = config.BP4D_LABEL_SPLIT
        elif database_name == "DISFA":
            label_split_list = config.DISFA_LABEL_SPLIT
        for couples_tuple in label_split_list:  # couples_tuple = ("1","3","5",.."4") cross AU_couple, config.LABEL_SPLIT come from frequent pattern statistics
            assert len(couples_tuple) == config.BOX_NUM[database_name]
            couples_tuple = tuple(map(str, sorted(map(int, couples_tuple))))
            couples_tuple_set = set(
                couples_tuple)  # use cartesian product to iterator over
            if len(couples_tuple_set) < len(couples_tuple):
                continue
            # limit too many combination
            # count = 0
            # for fp in fp_set:
            #     inter_set = couples_tuple_set & set(fp)
            #     union_set = couples_tuple_set | set(fp)
            #     iou = len(inter_set) / len(union_set)
            #     if iou > 0.6:
            #         count += 1
            # if count < 20:
            #     continue

            node_list = []
            temporal_edges = []
            spatio_edges = []
            h_info_array = []
            box_geometry_array = []
            for entry_dict in video_info:
                frame = entry_dict["frame"]
                cropped_face = entry_dict["cropped_face"]
                print("processing frame:{}".format(frame))
                all_couple_mask_dict = entry_dict[
                    "all_couple_mask_dict"]  # key is an AU couple tuple; masks are returned for every region whether or not the AU occurs on the face
                image_labels = entry_dict[
                    "all_labels"]  # each region has a label(binary or AU)

                bboxes = []
                labels = []
                AU_couple_bbox_dict = OrderedDict()

                if frame in frame_box_cache:
                    bboxes = frame_box_cache[frame]
                    labels = frame_labels_cache[frame]
                    AU_couple_bbox_dict = frame_AU_couple_bbox_dict_cache[
                        frame]
                else:

                    for idx, (AU_couple, mask) in enumerate(
                            all_couple_mask_dict.items()
                    ):  # We cannot sort this dict here, because region_label depend on order of this dict.AU may contain single_true AU or AU binary tuple (depends on need_adaptive_AU_relation)
                        region_label = image_labels[
                            idx]  # str or tuple, so all_labels index must be the same as all_couple_mask_dict
                        connect_arr = cv2.connectedComponents(mask,
                                                              connectivity=8,
                                                              ltype=cv2.CV_32S)
                        component_num = connect_arr[0]
                        label_matrix = connect_arr[1]
                        for component_label in range(1, component_num):
                            row_col = list(
                                zip(*np.where(
                                    label_matrix == component_label)))
                            row_col = np.array(row_col)
                            y_min_index = np.argmin(row_col[:, 0])
                            y_min = row_col[y_min_index, 0]
                            x_min_index = np.argmin(row_col[:, 1])
                            x_min = row_col[x_min_index, 1]
                            y_max_index = np.argmax(row_col[:, 0])
                            y_max = row_col[y_max_index, 0]
                            x_max_index = np.argmax(row_col[:, 1])
                            x_max = row_col[x_max_index, 1]
                            # same region may be shared by different AU, we must deal with it
                            coordinates = (y_min, x_min, y_max, x_max)

                            if y_min == y_max and x_min == x_max:
                                continue

                            if coordinates not in bboxes:
                                bboxes.append(
                                    coordinates
                                )  # bboxes and labels have the same order
                                labels.append(
                                    region_label
                                )  # AU may contain single_true AU or AU binary tuple (depends on need_adaptive_AU_relation)
                                AU_couple_bbox_dict[coordinates] = AU_couple

                        del label_matrix
                    if len(bboxes) != config.BOX_NUM[database_name]:
                        print("boxes num != {0}, real box num= {1}".format(
                            config.BOX_NUM[database_name], len(bboxes)))
                        continue
                frame_box_cache[frame] = bboxes
                frame_AU_couple_bbox_dict_cache[frame] = AU_couple_bbox_dict
                frame_labels_cache[frame] = labels
                box_idx_AU_dict = dict()  # box_idx => AU; cannot cache, because couples_tuple differs each iteration
                already_added_AU_set = set()
                for box_idx, _ in enumerate(bboxes):  # bboxes may from cache
                    AU_couple = list(AU_couple_bbox_dict.values())[
                        box_idx]  # AU_couple_bbox_dict may from cache
                    for AU in couples_tuple:  # couples_tuple is not from the cache, so it changes after each iteration
                        if AU in AU_couple and AU not in already_added_AU_set:
                            box_idx_AU_dict[box_idx] = (AU, AU_couple)
                            already_added_AU_set.add(AU)
                            break

                cropped_face.flags.writeable = False
                key = hash(cropped_face.data.tobytes())
                if key in extracted_feature_cache:
                    h = extracted_feature_cache[key]
                else:
                    with chainer.no_backprop_mode(), chainer.using_config(
                            'train', False):
                        h = faster_rcnn.extract(
                            cropped_face, bboxes,
                            layer=extract_key)  # shape = R' x 2048
                        extracted_feature_cache[key] = h
                    assert h.shape[0] == len(bboxes)
                h = chainer.cuda.to_cpu(h)
                h = h.reshape(len(bboxes), -1)

                # everything at this indentation level operates within one image
                # print("box number, all_mask:", len(bboxes), len(all_couple_mask_dict))
                assert len(box_idx_AU_dict) == config.BOX_NUM[database_name]
                for box_idx, (AU,
                              AU_couple) in sorted(box_idx_AU_dict.items(),
                                                   key=lambda e: int(e[0])):
                    label = np.zeros(
                        shape=label_bin_len, dtype=np.int32
                    )  # bin length became box number > AU_couple number
                    AU_squeeze_idx = config.AU_SQUEEZE.inv[AU]
                    label[couples_tuple.index(AU)] = labels[box_idx][
                        AU_squeeze_idx]  # caching labels can go wrong; labels[box_idx] = 0,0,1,1,...,0 but we only look at this specific idx
                    label = tuple(label)
                    label_arr = np.char.mod("%d", label)
                    label = "({})".format(",".join(label_arr))
                    h_flat = h[box_idx]
                    node_id = "{0}_{1}".format(frame, box_idx)
                    node_list.append(
                        "{0} {1} feature_idx:{2} AU_couple:{3} AU:{4}".format(
                            node_id, label, len(h_info_array), AU_couple, AU))
                    h_info_array.append(h_flat)
                    box_geometry_array.append(bboxes[box_idx])

                # pair up boxes within one frame and check whether they are connected; note that regions whose AU = 0 (AU absent) also take part in the connections
                for box_idx_a, box_idx_b in map(
                        sorted, itertools.combinations(range(len(bboxes)), 2)):
                    node_id_a = "{0}_{1}".format(frame, box_idx_a)
                    node_id_b = "{0}_{1}".format(frame, box_idx_b)
                    AU_couple_a = AU_couple_bbox_dict[bboxes[
                        box_idx_a]]  # AU couple represent region( maybe symmetry in face)
                    AU_couple_b = AU_couple_bbox_dict[bboxes[box_idx_b]]
                    if AU_couple_a == AU_couple_b or has_edge(
                            AU_couple_a, AU_couple_b, database_name):
                        spatio_edges.append("#edge {0} {1} spatio".format(
                            node_id_a, node_id_b))

            box_id_temporal_dict = defaultdict(
                list)  # key = roi/bbox id, value = node_id list cross temporal
            for node_info in node_list:
                node_id = node_info[0:node_info.index(" ")]
                box_id = node_id[node_id.index("_") + 1:]
                box_id_temporal_dict[box_id].append(node_id)

            for node_id_list in box_id_temporal_dict.values():
                for idx, node_id in enumerate(node_id_list):
                    if idx + 1 < len(node_id_list):
                        node_id_next = node_id_list[idx + 1]
                        temporal_edges.append("#edge {0} {1} temporal".format(
                            node_id, node_id_next))
            train_AU_out_path = "{0}/train/{1}/{2}.txt".format(
                output_dir, "_".join(map(str, couples_tuple)),
                video_info[0]["video_id"])
            test_AU_out_path = "{0}/test/{1}/{2}.txt".format(
                output_dir, "_".join(map(str, couples_tuple)),
                video_info[0]["video_id"])
            if subject_id in train_subject:
                output_path = train_AU_out_path
                npz_path = output_dir + os.sep + "train" + os.sep + os.path.basename(
                    output_path)[:os.path.basename(output_path).
                                 rindex(".")] + ".npz"
            elif subject_id in test_subject:
                output_path = test_AU_out_path
                npz_path = output_dir + os.sep + "test" + os.sep + os.path.basename(
                    output_path)[:os.path.basename(output_path).
                                 rindex(".")] + ".npz"
            else:
                continue  # subject belongs to neither split; skip this video
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            if not os.path.exists(npz_path):
                np.savez(npz_path,
                         appearance_features=h_info_array,
                         geometry_features=np.array(box_geometry_array,
                                                    dtype=np.float32))
            with open(output_path, "w") as file_obj:
                for line in node_list:
                    file_obj.write("{}\n".format(line))
                for line in spatio_edges:
                    file_obj.write("{}\n".format(line))
                for line in temporal_edges:
                    file_obj.write("{}\n".format(line))
                file_obj.flush()
                node_list.clear()
                spatio_edges.clear()
                temporal_edges.clear()
                h_info_array.clear()
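
A minimal standalone sketch of the node/edge text format described in the docstring (toy values only; the real writer above also stores feature indices and box geometry in a companion .npz file):

node_list = [
    "1_0 (0,1,0) feature_idx:0 AU_couple:('1', '2') AU:2",
    "2_0 (0,0,0) feature_idx:1 AU_couple:('1', '2') AU:2",
]
spatio_edges = ["#edge 1_0 1_1 spatio"]
temporal_edges = ["#edge 1_0 2_0 temporal"]

with open("toy_video.txt", "w") as file_obj:
    for line in node_list + spatio_edges + temporal_edges:
        file_obj.write("{}\n".format(line))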
Code example #19
def read_DISFA_video_label(output_dir,
                           is_binary_AU,
                           is_need_adaptive_AU_relation=False,
                           force_generate=True,
                           proc_num=10,
                           cut=False,
                           train_subject=None):
    mgr = mp.Manager()
    queue = mgr.Queue(maxsize=20000)
    for orientation in ["Left", "Right"]:
        if is_need_adaptive_AU_relation:
            adaptive_AU_relation()  # delete AU relation pairs that occur in the same facial region
        au_couple_dict = get_zip_ROI_AU()
        au_couple_child_dict = get_AU_couple_child(au_couple_dict)
        DISFA_base_dir = config.DATA_PATH["DISFA"]
        label_file_dir = DISFA_base_dir + "/ActionUnit_Labels/"
        img_folder = DISFA_base_dir + "/Img_{}Camera".format(orientation)
        for video_name in os.listdir(label_file_dir):  # each file is a video
        is_train = video_name in train_subject
            if not force_generate:
                prefix = "train" if is_train else "test"
                target_file_path = output_dir + os.sep + prefix + os.sep + video_name + "_" + orientation + ".npz"
                if os.path.exists(target_file_path):
                    continue
            resultdict = {}
            if proc_num > 1:
                pool = mp.Pool(processes=proc_num)
                procs = 0
                one_file_name = os.listdir(label_file_dir + os.sep +
                                           video_name)[0]
                with open(
                        label_file_dir + os.sep + video_name + os.sep +
                        one_file_name, "r") as file_obj:
                    for idx, line in enumerate(file_obj):
                        line = line.strip()
                        if line:
                            frame = line.split(",")[0]
                            img_path = img_folder + "/{0}/{1}.jpg".format(
                                video_name, frame)

                            pool.apply_async(func=delegate_mask_crop,
                                             args=(img_path, True, queue))
                            procs += 1
                for i in range(procs):
                    try:
                        entry = queue.get(block=True, timeout=60)
                        resultdict[entry[0]] = (entry[1], entry[2])
                    except Exception:
                        print("queue block time out")
                        break
                pool.close()
                pool.join()
                del pool
            else:  # only one process
                one_file_name = os.listdir(label_file_dir + os.sep +
                                           video_name)[0]
                with open(
                        label_file_dir + os.sep + video_name + os.sep +
                        one_file_name, "r") as file_obj:
                    for idx, line in enumerate(file_obj):
                        line = line.strip()
                        if line:
                            frame = line.split(",")[0]
                            img_path = img_folder + "/{0}/{1}.jpg".format(
                                video_name, frame)
                            try:
                                cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
                                    img_path, True)
                                resultdict[img_path] = (cropped_face,
                                                        AU_mask_dict)
                            except IndexError:
                                pass

            frame_label = dict()
            video_info = []
            video_img_path_set = set()
            for file_name in os.listdir(
                    label_file_dir + os.sep +
                    video_name):  # each file is one AU ( video file )
                AU = file_name[file_name.index("au") + 2:file_name.rindex(".")]

                with open(
                        label_file_dir + os.sep + video_name + os.sep +
                        file_name, "r") as file_obj:
                    for line in file_obj:
                        frame = int(line.split(",")[0])
                        AU_intensity = int(line.split(",")[1])
                        img_path = img_folder + "/{0}/{1}.jpg".format(
                            video_name, frame)
                        video_img_path_set.add((frame, img_path))
                        if frame not in frame_label:
                            frame_label[frame] = set()
                        if AU_intensity >= 1:  # NOTE that we apply AU_intensity >= 1
                            frame_label[int(frame)].add(AU)  # stored as str
            for frame, img_path in sorted(video_img_path_set,
                                          key=lambda e: int(e[0])):
                if img_path not in resultdict:
                    continue
                AU_set = frame_label[frame]  # it is whole image's AU set

                if cut and len(AU_set) == 0:
                    continue
                cropped_face, AU_mask_dict = resultdict[img_path]

                all_couple_mask_dict = OrderedDict()
                for AU in sorted(map(int, config.AU_ROI.keys())):  # ensure the same order
                    all_couple_mask_dict[au_couple_dict[str(AU)]] = AU_mask_dict[str(AU)]

                all_labels = list()  # start assembling all_labels
                for AU_couple in all_couple_mask_dict.keys():  # same order as all_couple_mask_dict
                    child_AU_couple_list = au_couple_child_dict[AU_couple]
                    AU_couple = set(AU_couple)
                    for child_AU_couple in child_AU_couple_list:
                        AU_couple.update(
                            child_AU_couple)  # combine child region's AU
                    if not is_binary_AU:  # CRF mode joins the multiple AUs of one region with commas
                        concat_AU = []
                        for AU in AU_couple:
                            if AU in AU_set:  # AU_set stores the ground-truth AU labels as str
                                concat_AU.append(AU)

                        if len(concat_AU) == 0:
                            all_labels.append("0")  # if no AU occurs in this region at all, append 0 so the single-label-only CRF still works
                        else:
                            all_labels.append(",".join(sorted(concat_AU)))

                    else:  # convert to np.array which is AU_bin
                        AU_bin = np.zeros(len(config.AU_SQUEEZE)).astype(
                            np.uint8)
                        for AU in AU_couple:
                            if AU in AU_set:  # judge if this region contain which subset of whole image's AU_set
                                np.put(AU_bin, config.AU_SQUEEZE.inv[AU], 1)
                        all_labels.append(tuple(AU_bin))

                video_info.append({
                    "frame": frame,
                    "cropped_face": cropped_face,
                    "all_couple_mask_dict": all_couple_mask_dict,
                    "all_labels": all_labels,
                    "video_id": video_name + "_" + orientation
                })
            resultdict.clear()
            if video_info:
                yield video_info, video_name
            else:
                print("error in file: {}, no video found".format(
                    video_name + "_" + orientation))
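
The binary branch turns a region's AU set into a fixed-length 0/1 vector indexed through config.AU_SQUEEZE.inv. A self-contained sketch with a plain dict pair standing in for the bidict (toy AU ids):

import numpy as np

AU_SQUEEZE = {0: "1", 1: "2", 2: "12"}                  # toy index -> AU
AU_SQUEEZE_inv = {v: k for k, v in AU_SQUEEZE.items()}  # AU -> index

region_AUs = {"2", "12"}  # AUs active in this region
AU_bin = np.zeros(len(AU_SQUEEZE), dtype=np.uint8)
for AU in region_AUs:
    np.put(AU_bin, AU_SQUEEZE_inv[AU], 1)

print(tuple(AU_bin))  # (0, 1, 1)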
Code example #20
def read_BP4D_video_label(output_dir,
                          is_binary_AU,
                          is_need_adaptive_AU_relation=False,
                          force_generate=True,
                          proc_num=10,
                          cut=False,
                          train_subject=None):
    '''
    :param
            output_dir : used to check whether the target file already exists; if it does, it is not regenerated
            is_binary_AU:
                          True --> return AU_binary, e.g. 01010100
                          False --> used for CRF mode: a single true-AU label for CRF, or an AU combination separated by commas
    :yield:  each video is yielded once fully collected; every frame of the video contributes three parts:
            1. "img_path": /path/to/image
            2. "all_couple_mask_dict": an OrderedDict holding the masks of all regions, whether the AU is +1 or -1
               (i.e. whether or not the AU occurs); keys are AU_couple tuples from au_couple_dict = get_zip_ROI_AU()
            3. "labels": a list whose indices match all_couple_mask_dict; each label is
               either in binary form, e.g. 01010110,
               or like "3,4" (several AUs can occur at one location, so a comma-separated string is returned),
               depending on is_binary_AU
    '''
    mgr = mp.Manager()
    queue = mgr.Queue(maxsize=20000)

    if is_need_adaptive_AU_relation:
        adaptive_AU_relation()  # delete AU relation pairs that occur in the same facial region
    au_couple_dict = get_zip_ROI_AU()
    au_couple_child_dict = get_AU_couple_child(au_couple_dict)  # AU_couple => list of child AU_couple
    # if need_translate_combine_AU ==> "mask_path_dict":{(2,3,4): /pathtomask} convert to "mask_path_dict":{110: /pathtomask}
    # each is dict : {"img": /path/to/img, "mask_path_dict":{(2,3,4): /pathtomask}, }
    BP4D_base_dir_path = config.DATA_PATH["BP4D"]
    label_file_dir = BP4D_base_dir_path + "/AUCoding/"

    for file_name in os.listdir(label_file_dir):  # each file is a video

        subject_name = file_name[:file_name.index("_")]
        sequence_name = file_name[file_name.index("_") +
                                  1:file_name.rindex(".")]
        is_train = subject_name in train_subject
        if not force_generate:
            prefix = "train" if is_train else "test"
            target_file_path = output_dir + os.sep + prefix + os.sep + subject_name + "_" + sequence_name + ".npz"
            if os.path.exists(target_file_path):
                continue
        resultdict = {}
        if proc_num > 1:

            one_image_path = os.listdir(config.RGB_PATH["BP4D"] + os.sep +
                                        subject_name + os.sep +
                                        sequence_name)[0]
            zfill_len = len(one_image_path[:one_image_path.rindex(".")])

            procs = 0
            # read image file and crop and get AU mask
            pool = mp.Pool(processes=proc_num)
            with open(label_file_dir + "/" + file_name,
                      "r") as au_file_obj:  # each file is a video
                for idx, line in enumerate(au_file_obj):

                    if idx == 0:
                        continue
                    lines = line.split(",")
                    frame = lines[0].zfill(zfill_len)

                    img_path = config.RGB_PATH[
                        "BP4D"] + os.sep + subject_name + os.sep + sequence_name + os.sep + frame + ".jpg"
                    if not os.path.exists(img_path):
                        print("not exists img_path:{}".format(img_path))
                        continue

                    pool.apply_async(func=delegate_mask_crop,
                                     args=(img_path, True, queue))
                    procs += 1
                    # p = mp.Process(target=delegate_mask_crop, args=(img_path, True, queue))
                    # procs.append(p)
                    # p.start()

            for i in range(procs):
                try:
                    entry = queue.get(block=True, timeout=360)
                    resultdict[entry[0]] = (entry[1], entry[2])
                except Exception:
                    print("queue block time out")
                    break
            pool.close()
            pool.join()
            del pool
        else:  # only one process
            # use one existing frame filename to infer the zero-padding width of frame numbers
            one_image_path = os.listdir(config.RGB_PATH["BP4D"] + os.sep +
                                        subject_name + os.sep +
                                        sequence_name)[0]
            zfill_len = len(one_image_path[:one_image_path.rindex(".")])
            with open(label_file_dir + "/" + file_name,
                      "r") as au_file_obj:  # each file is a video
                for idx, line in enumerate(au_file_obj):
                    if idx == 0:
                        continue  # skip the CSV header row
                    lines = line.split(",")
                    frame = lines[0].zfill(zfill_len)

                    img_path = config.RGB_PATH["BP4D"] + os.sep + subject_name + os.sep + sequence_name + os.sep + frame + ".jpg"
                    if not os.path.exists(img_path):
                        print("not exists img_path:{}".format(img_path))
                        continue
                    try:
                        cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
                            img_path, channel_first=True)
                        # note that above AU_mask_dict, each AU may have mask that contains multiple separate regions
                        resultdict[img_path] = (cropped_face, AU_mask_dict)
                        print("one image :{} done".format(img_path))
                    except IndexError:
                        print("img_path:{} cannot obtain 68 landmarks".format(
                            img_path))
        AU_column_idx = {}
        with open(label_file_dir + "/" + file_name,
                  "r") as au_file_obj:  # each file is a video

            video_info = []
            for idx, line in enumerate(
                    au_file_obj):  # each line represent a frame image

                line = line.rstrip()
                lines = line.split(",")
                if idx == 0:  # the header row defines which column holds which Action Unit
                    for col_idx, AU in enumerate(lines[1:]):
                        AU_column_idx[AU] = col_idx + 1  # read header
                    continue  # header parsed, move on to the data rows

                frame = lines[0].zfill(zfill_len)

                img_path = config.RGB_PATH["BP4D"] + os.sep + subject_name + os.sep + sequence_name + os.sep + frame + ".jpg"
                if not os.path.exists(img_path):
                    print("not exists img_path:{}".format(img_path))
                    continue
                if img_path not in resultdict:
                    print("img_path:{} landmark not found, continue".format(
                        img_path))
                    continue
                cropped_face, AU_mask_dict = resultdict[img_path]

                all_couple_mask_dict = OrderedDict()
                for AU in sorted(map(int, config.AU_ROI.keys())):  # ensure the same order every time
                    all_couple_mask_dict[au_couple_dict[str(AU)]] = AU_mask_dict[str(AU)]

                au_label_dict = {
                    AU: int(lines[AU_column_idx[AU]])
                    for AU in config.AU_ROI.keys()
                }  # store real AU label
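                # BP4D AU coding: 0 = AU absent, 1 = AU present, 9 = unknown/occluded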
                if cut and all(_au_label == 0
                               for _au_label in au_label_dict.values()):
                    continue
                all_labels = list()  # start assembling all_labels
                for AU_couple in all_couple_mask_dict.keys():  # same order as all_couple_mask_dict
                    child_AU_couple_list = au_couple_child_dict[AU_couple]
                    AU_couple = set(AU_couple)
                    for child_AU_couple in child_AU_couple_list:
                        AU_couple.update(
                            child_AU_couple
                        )  # label fetch: combine child region's AU
                    if not is_binary_AU:  # CRF mode: AUs sharing one region are joined into a comma-separated string
                        concat_AU = []
                        for AU in AU_couple:
                            if au_label_dict[AU] == 1:
                                concat_AU.append(AU)
                            elif au_label_dict[AU] == 9:
                                concat_AU.append("?{}".format(AU))

                        if len(concat_AU) == 0:
                            all_labels.append(
                                "0")  # no AU occurs in this region; use "0" so the single-label CRF still works
                        else:
                            all_labels.append(",".join(concat_AU))

                    else:  # binary mode: build a fixed-length AU_bin vector indexed by config.AU_SQUEEZE
                        # int8 rather than uint8: the "unknown" label (9) is stored as -1
                        AU_bin = np.zeros(len(config.AU_SQUEEZE), dtype=np.int8)
                        for AU in AU_couple:
                            if au_label_dict[AU] == 9:
                                np.put(AU_bin, config.AU_SQUEEZE.inv[AU], -1)
                            elif au_label_dict[AU] == 1:
                                np.put(AU_bin, config.AU_SQUEEZE.inv[AU], 1)

                        all_labels.append(tuple(AU_bin))
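                        # e.g. with 8 AUs in AU_SQUEEZE, one entry might read (0, 0, -1, 0, 1, 0, 0, 0)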

                video_info.append({
                    "frame": frame,
                    "cropped_face": cropped_face,
                    "all_couple_mask_dict": all_couple_mask_dict,
                    "all_labels": all_labels,
                    "video_id": subject_name + "_" + sequence_name
                })
        resultdict.clear()
        if video_info:
            yield video_info, subject_name
        else:
            print("no valid frames collected from file: {}".format(file_name))
コード例 #21
0
def generate_mask_contain_img(database_name, img_path):
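    # Draw the face's ROI polygons onto the cropped image, then overlay each AU
    # region mask, coloring parent regions and their child regions differently
    # (colors come from MASK_CONTAIN); returns the composited image.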
    adaptive_AU_database(database_name)
    mask_color = {}
    for parent_color, child_color in MASK_CONTAIN.items():
        mask_color[color_bgr(parent_color)] = color_bgr(child_color)
    cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
        img_path, channel_first=False)

    AU_couple_dict = get_zip_ROI_AU()
    AU_couple_child = get_AU_couple_child(AU_couple_dict)
    land = FaceLandMark(config.DLIB_LANDMARK_PRETRAIN)
    landmark, _, _ = land.landmark(image=cropped_face)
    roi_polygons = land.split_ROI(landmark)
    for roi_no, polygon_vertex_arr in roi_polygons.items():
        polygon_vertex_arr[0, :] = np.round(polygon_vertex_arr[0, :])
        polygon_vertex_arr[1, :] = np.round(polygon_vertex_arr[1, :])
        polygon_vertex_arr = sort_clockwise(polygon_vertex_arr.tolist())
        cv2.polylines(cropped_face, [polygon_vertex_arr],
                      True,
                      color_bgr(RED),
                      thickness=1)
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(cropped_face,
                    str(roi_no),
                    tuple(
                        np.mean(polygon_vertex_arr, axis=0).astype(np.int32)),
                    font,
                    0.7, (0, 255, 255),
                    thickness=1)
    already_fill_AU = set()
    # collect every AU_couple that appears as a child of some other region
    all_child_set = set()
    for child_set in AU_couple_child.values():
        for child in child_set:
            all_child_set.add(child)
    new_face = np.zeros_like(cropped_face)
    for AU in config.AU_ROI.keys():
        AU_couple = AU_couple_dict[AU]
        if AU_couple in all_child_set:
            continue
        if AU_couple in already_fill_AU:
            continue
        already_fill_AU.add(AU_couple)
        mask = AU_mask_dict[AU]
        child_AU_set = AU_couple_child[AU_couple]
        color_parent = list(MASK_CONTAIN.keys())[0]
        color_child = MASK_CONTAIN[color_parent]
        color_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)

        color_mask[mask != 0] = color_parent
        if np.any(new_face):
            cropped_face = new_face
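        # cv2.addWeighted writes cropped_face + 0.3 * color_mask into new_face (dst); dtype -1 keeps the source depth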
        cv2.addWeighted(cropped_face, 1, color_mask, 0.3, 0, new_face, -1)

        for child_AU in child_AU_set:
            if child_AU in already_fill_AU:
                continue
            already_fill_AU.add(child_AU)
            mask = AU_mask_dict[child_AU[0]]
            color_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
            color_mask[mask != 0] = random.choice(color_child)
            cv2.addWeighted(new_face, 1, color_mask, 0.5, 0, new_face, -1)

    return new_face
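
# Hypothetical usage sketch (paths are placeholders, not from the original repo):
#   overlay = generate_mask_contain_img("BP4D", "/path/to/face.jpg")
#   cv2.imwrite("/tmp/au_mask_overlay.png", overlay)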