コード例 #1
0
    def __init__(self,
                 database,
                 L,
                 fold,
                 split_name,
                 split_index,
                 mc_manager,
                 train_all_data,
                 two_stream_mode,
                 pretrained_target="",
                 paper_report_label_idx=None):
        """Read the id list of one fold/split and collect its existing samples.

        Every non-blank line of the id file carries four tab-separated
        fields: a relative image path, a comma-separated AU list (``"0"``
        means "no AU"), a field that is ignored here, and the database name.
        Only samples whose RGB image exists on disk are kept.  They are stored
        in ``self.result_data`` as ``(rgb_path, flow_path, AU_set,
        database_name)`` tuples, sorted by the two directory names that
        precede the file name and then by the numeric frame number.

        :param database: key into ``config.DATA_PATH`` (BP4D / DISFA /
            BP4D_DISFA)
        :param L: stored as ``self.L``; used for the optical flow image fetch
            before L/2 and after L/2
        :param fold: fold count, selects the ``idx/<fold>_fold`` directory
        :param split_name: split name used in the id file name
        :param split_index: split index used in the id file name
        :param mc_manager: stored as ``self.mc_manager``
        :param train_all_data: when true, ``full_pretrain.txt`` is read
            instead of the per-split id file
        :param two_stream_mode: stored as ``self.two_stream_mode``
        :param pretrained_target: stored as ``self.pretrained_target``
        :param paper_report_label_idx: stored as
            ``self.paper_report_label_idx``
        """

        def _sample_order(sample):
            # rgb path layout: .../<dir>/<dir>/<frame number>.<ext>
            pieces = sample[0].split("/")
            outer_dir, inner_dir = pieces[-3], pieces[-2]
            frame_file = pieces[-1]
            return (outer_dir, inner_dir,
                    int(frame_file[:frame_file.rindex(".")]))

        self.database = database
        self.split_name = split_name
        # used for the optical flow image fetch at before L/2 and after L/2
        self.L = L
        self.au_couple_dict = get_zip_ROI_AU()
        self.mc_manager = mc_manager
        self.au_couple_child_dict = get_AU_couple_child(self.au_couple_dict)
        # subject + "/" + emotion_seq + "/" + frame => ... not implemented
        self.AU_intensity_label = {}
        self.pretrained_target = pretrained_target
        self.two_stream_mode = two_stream_mode
        self.dir = config.DATA_PATH[database]  # BP4D/DISFA/ BP4D_DISFA
        self.paper_report_label_idx = paper_report_label_idx

        fold_dir = self.dir + "/idx/{}_fold".format(fold)
        if train_all_data:
            list_name = "full_pretrain.txt"
        else:
            list_name = "id_{0}_{1}.txt".format(split_name, split_index)
        id_list_file_path = os.path.join(fold_dir, list_name)

        self.result_data = []
        print("idfile:{}".format(id_list_file_path))
        with open(id_list_file_path, "r") as file_obj:
            for raw_line in file_obj:
                record = raw_line.rstrip()
                if not record:
                    continue
                relative_path, au_set_str, _, current_database_name = \
                    record.split("\t")
                if au_set_str == "0":
                    AU_set = set()
                else:
                    # keep only AUs known to both the ROI table and the
                    # squeezed label index
                    AU_set = {
                        AU
                        for AU in au_set_str.split(',')
                        if AU in config.AU_ROI
                        and AU in config.AU_SQUEEZE.inv
                    }
                # paths in the id file are relative to the database roots
                rgb_path = (config.RGB_PATH[current_database_name] +
                            os.path.sep + relative_path)
                flow_path = (config.FLOW_PATH[current_database_name] +
                             os.path.sep + relative_path)
                if not os.path.exists(rgb_path):
                    continue
                self.result_data.append(
                    (rgb_path, flow_path, AU_set, current_database_name))

        self.result_data.sort(key=_sample_order)
        self._num_examples = len(self.result_data)
        print("read id file done, all examples:{}".format(self._num_examples))
コード例 #2
0
ファイル: AU_dataset.py プロジェクト: zhangxujinsh/AU_R-CNN
    def __init__(self,
                 img_resolution,
                 database,
                 fold,
                 split_name,
                 split_index,
                 mc_manager,
                 prefix="",
                 pretrained_target=""):
        """Read an AU-intensity id list and index its samples per video.

        Every non-blank line of ``intensity_<split_name>_<split_index>.txt``
        carries four tab-separated fields: a relative image path, a
        comma-separated list of integer AU intensities, a "from" image path
        (``"#"`` means "same as the image path") and the database name.
        Samples whose image exists on disk are stored in ``self.result_data``
        as ``(img_path, from_img_path, AU_intensity, database_name)`` tuples,
        sorted by the two directory names preceding the file name and then by
        the numeric frame number.

        ``self.video_offset`` maps ``"<dir>/<dir>"`` video ids to the index of
        the video's first sample in the *sorted* ``self.result_data``, and
        ``self.video_count`` maps them to the number of samples actually
        kept.  Both are derived after filtering and sorting so that they
        always agree with ``self.result_data``.

        :param img_resolution: stored as ``self.img_resolution``
        :param database: key into ``config.DATA_PATH`` (BP4D / DISFA /
            BP4D_DISFA)
        :param fold: fold count, selects the ``idx/<fold>_fold<prefix>``
            directory
        :param split_name: split name used in the id file name
        :param split_index: split index used in the id file name
        :param mc_manager: stored as ``self.mc_manager``
        :param prefix: suffix appended to the fold directory name
        :param pretrained_target: stored as ``self.pretrained_target``
        """

        def _sample_order(entry):
            # image path layout: .../<dir>/<dir>/<frame number>.<ext>
            parts = entry[0].split("/")
            outer_dir, inner_dir = parts[-3], parts[-2]
            frame_file = parts[-1]
            return (outer_dir, inner_dir,
                    int(frame_file[:frame_file.rindex(".")]))

        self.database = database
        self.img_resolution = img_resolution
        self.split_name = split_name
        self.au_couple_dict = get_zip_ROI_AU()
        self.mc_manager = mc_manager
        self.au_couple_child_dict = get_AU_couple_child(self.au_couple_dict)
        # subject + "/" + emotion_seq + "/" + frame => ... not implemented
        self.AU_intensity_label = {}
        self.pretrained_target = pretrained_target
        self.dir = config.DATA_PATH[database]  # BP4D/DISFA/ BP4D_DISFA

        id_list_file_path = os.path.join(
            self.dir + "/idx/{0}_fold{1}".format(fold, prefix),
            "intensity_{0}_{1}.txt".format(split_name, split_index))
        self.result_data = []
        self.video_offset = OrderedDict()
        self.video_count = defaultdict(int)

        print("idfile:{}".format(id_list_file_path))
        with open(id_list_file_path, "r") as file_obj:
            for line in file_obj:
                line = line.rstrip()
                if not line:
                    continue
                img_path, au_set_str, from_img_path, current_database_name = \
                    line.split("\t")
                AU_intensity = np.fromstring(au_set_str,
                                             dtype=np.int32,
                                             sep=',')
                if from_img_path == "#":
                    from_img_path = img_path
                # paths in the id file are relative to the database root
                rgb_root = config.RGB_PATH[current_database_name]
                img_path = rgb_root + os.path.sep + img_path
                from_img_path = rgb_root + os.path.sep + from_img_path
                if os.path.exists(img_path):
                    self.result_data.append(
                        (img_path, from_img_path, AU_intensity,
                         current_database_name))

        self.result_data.sort(key=_sample_order)

        # Build the per-video index only now: doing it while reading (before
        # the missing-image filter and the sort) lets offsets and counts drift
        # away from the final contents of result_data.
        for offset, entry in enumerate(self.result_data):
            parts = entry[0].split("/")
            video_id = "/".join([parts[-3], parts[-2]])
            if video_id not in self.video_offset:
                self.video_offset[video_id] = offset
            self.video_count[video_id] += 1

        self._num_examples = len(self.result_data)
        print("read id file done, all examples:{}".format(self._num_examples))
コード例 #3
0
    def __init__(self,
                 database,
                 fold,
                 split_name,
                 split_index,
                 mc_manager,
                 train_all_data=False,
                 read_type="rgb",
                 pretrained_target="",
                 img_resolution=config.IMG_SIZE[0]):
        """Read the id list of one fold/split and collect its existing samples.

        Every non-blank line of the id file carries four tab-separated
        fields: a relative image path, a comma-separated AU list (``"0"``
        means "no AU"), a "from" image path (``"#"`` means "same as the image
        path") and the database name.  Samples whose image exists on disk are
        stored in ``self.result_data`` as ``(img_path, from_img_path, AU_set,
        database_name)`` tuples, sorted by the two directory names preceding
        the file name and then by the numeric frame number.

        :param database: key into ``config.DATA_PATH`` (BP4D / DISFA /
            BP4D_DISFA)
        :param fold: fold count, selects the ``idx/<fold>_fold`` directory
        :param split_name: split name (trainval or test) used in the id file
            name
        :param split_index: split index used in the id file name
        :param mc_manager: stored as ``self.mc_manager``
        :param train_all_data: when true, ``full_pretrain.txt`` is read
            instead of the per-split id file
        :param read_type: stored as ``self.read_type``
        :param pretrained_target: stored as ``self.pretrained_target``
        :param img_resolution: stored as ``self.img_resolution``
        """

        def _sample_order(sample):
            # image path layout: .../<dir>/<dir>/<frame number>.<ext>
            pieces = sample[0].split("/")
            outer_dir, inner_dir = pieces[-3], pieces[-2]
            frame_file = pieces[-1]
            return (outer_dir, inner_dir,
                    int(frame_file[:frame_file.rindex(".")]))

        self.database = database
        self.img_resolution = img_resolution
        self.split_name = split_name  # trainval or test
        self.read_type = read_type
        self.au_couple_dict = get_zip_ROI_AU()
        self.au_couple_child_dict = get_AU_couple_child(self.au_couple_dict)
        # subject + "/" + emotion_seq + "/" + frame => ... not implemented
        self.AU_intensity_label = {}
        self.dir = config.DATA_PATH[database]  # BP4D/DISFA/ BP4D_DISFA
        self.pretrained_target = pretrained_target
        self.mc_manager = mc_manager

        fold_dir = self.dir + "/idx/{}_fold".format(fold)
        if train_all_data:
            list_name = "full_pretrain.txt"
        else:
            list_name = "id_{0}_{1}.txt".format(split_name, split_index)
        id_list_file_path = os.path.join(fold_dir, list_name)

        self.result_data = []
        print("idfile:{}".format(id_list_file_path))
        with open(id_list_file_path, "r") as file_obj:
            for raw_line in file_obj:
                record = raw_line.rstrip()
                if not record:
                    continue
                img_path, au_set_str, from_img_path, current_database_name = \
                    record.split("\t")
                if au_set_str == "0":
                    AU_set = set()
                else:
                    AU_set = {
                        AU
                        for AU in au_set_str.split(',')
                        if AU in config.AU_ROI
                    }
                if from_img_path == "#":
                    from_img_path = img_path

                # paths in the id file are relative to the database root
                rgb_root = config.RGB_PATH[current_database_name]
                img_path = rgb_root + os.path.sep + img_path
                from_img_path = rgb_root + os.path.sep + from_img_path
                if not os.path.exists(img_path):
                    continue
                self.result_data.append(
                    (img_path, from_img_path, AU_set, current_database_name))

        self.result_data.sort(key=_sample_order)
        self._num_examples = len(self.result_data)
        print("read id file done, all examples:{}".format(self._num_examples))
コード例 #4
0
ファイル: demo_AU_rcnn.py プロジェクト: zhangxujinsh/AU_R-CNN
def main():
    """Demo: show class-activation heat maps over the AU regions of one face.

    Parses the command line, builds the selected Faster R-CNN variant, loads
    the snapshot given by ``--model`` when it exists, crops the face in
    ``--image`` and, for every ROI box with at least one non-zero predicted
    class, overlays the activation map of one randomly chosen predicted class
    on the ROI and shows it in an OpenCV window (blocking until a key is
    pressed).
    """
    parser = argparse.ArgumentParser(
        description='generate Graph desc file script')
    parser.add_argument('--mean',
                        default=config.ROOT_PATH + "BP4D/idx/mean_rgb.npy",
                        help='image mean .npy file')
    parser.add_argument("--image",
                        default='C:/Users/machen/Downloads/tmp/face.jpg')
    parser.add_argument(
        "--model", default="C:/Users/machen/Downloads/tmp/BP4D_3_fold_1.npz")
    # Restricting the choices makes an unsupported backbone fail at parse time
    # instead of leaving ``faster_rcnn`` undefined (NameError further down).
    parser.add_argument("--pretrained_model_name",
                        '-premodel',
                        default='resnet101',
                        choices=('resnet101', 'vgg'))
    parser.add_argument('--database',
                        default='BP4D',
                        help='AU database name, e.g. BP4D')
    parser.add_argument('--device',
                        default=0,
                        type=int,
                        help='GPU device number')
    args = parser.parse_args()
    adaptive_AU_database(args.database)

    if args.pretrained_model_name == "resnet101":
        # the snapshot path could be changed to
        # /home/machen/face_expr/result/snapshot_model.npz
        faster_rcnn = FasterRCNNResnet101(
            n_fg_class=len(config.AU_SQUEEZE),
            pretrained_model="resnet101",
            mean_file=args.mean,
            use_lstm=False,
            extract_len=1000)
    else:  # "vgg", guaranteed by the argparse choices above
        faster_rcnn = FasterRCNNVGG16(n_fg_class=len(config.AU_SQUEEZE),
                                      pretrained_model="imagenet",
                                      mean_file=args.mean,
                                      use_lstm=False,
                                      extract_len=1000)

    if os.path.exists(args.model):
        print("loading pretrained snapshot:{}".format(args.model))
        chainer.serializers.load_npz(args.model, faster_rcnn)
    if args.device >= 0:
        faster_rcnn.to_gpu(args.device)
        chainer.cuda.get_device_from_id(int(args.device)).use()

    # NOTE(review): unlike the snapshot loading above, this requires
    # args.model to exist -- np.load raises otherwise.
    heatmap_gen = HeatMapGenerator(np.load(args.model), use_relu=True)
    if args.device >= 0:
        heatmap_gen.to_gpu(args.device)
    cropped_face, AU_box_dict = FaceMaskCropper.get_cropface_and_box(
        args.image, args.image, channel_first=True)
    au_couple_dict = get_zip_ROI_AU()
    au_couple_child = get_AU_couple_child(
        au_couple_dict)  # AU couple tuple => child fetch list
    au_couple_box = dict()  # value is box (4 tuple coordinate) list

    for AU, AU_couple in au_couple_dict.items():
        au_couple_box[AU_couple] = AU_box_dict[AU]
    box_lst = []
    roi_no_AU_couple_dict = dict()
    roi_no = 0
    for AU_couple, couple_box_lst in au_couple_box.items():
        box_lst.extend(couple_box_lst)
        for _ in couple_box_lst:
            roi_no_AU_couple_dict[roi_no] = AU_couple
            roi_no += 1

    box_lst = np.asarray(box_lst)
    cropped_face = cropped_face.astype(np.float32)
    orig_face = cropped_face
    cropped_face = faster_rcnn.prepare(
        cropped_face)  # subtract mean pixel value
    box_lst = box_lst.astype(np.float32)
    orig_box_lst = box_lst
    batch = [
        (cropped_face, box_lst),
    ]
    cropped_face, box_lst = concat_examples(
        batch, args.device)  # N,3, H, W, ;  N, F, 4

    if box_lst.shape[1] != config.BOX_NUM[args.database]:
        print("error box num {0} != {1}".format(box_lst.shape[1],
                                                config.BOX_NUM[args.database]))
        return
    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        cropped_face = chainer.Variable(cropped_face)
        box_lst = chainer.Variable(box_lst)
        roi_preds, _ = faster_rcnn.predict(cropped_face, box_lst)  # R, 22
        roi_feature_maps = faster_rcnn.extract(orig_face, orig_box_lst,
                                               'res5')  # R, 2048 7,7

        roi_images = []
        box_lst = box_lst[0].data.astype(np.int32)
        for box in box_lst:
            y_min, x_min, y_max, x_max = box
            roi_image = orig_face[:, y_min:y_max + 1,
                                  x_min:x_max + 1]  # N, 3, roi_H, roi_W
            roi_images.append(roi_image)  # list of  N, 3, roi_H, roi_W
        cmap = plt.get_cmap('jet')
        for box_id, (roi_image, roi_feature_map) in enumerate(
                zip(roi_images, roi_feature_maps)):
            y_min, x_min, y_max, x_max = box_lst[box_id]
            # 22, roi_h, roi_w, 3
            xp = chainer.cuda.get_array_module(roi_feature_map)
            roi_feature_map = xp.expand_dims(roi_feature_map, 0)
            #   class_roi_overlay_img = 22, roi_h, roi_w
            class_roi_activate_img = heatmap_gen.generate_activate_roi_map(
                roi_feature_map, (y_max - y_min + 1, x_max - x_min + 1))
            roi_pred = roi_preds[box_id]  # 22
            predicted_classes = np.nonzero(roi_pred)[0]
            # TODO: still need a proper class selection, and the heat maps
            # should be accumulated with cv2.add
            if len(predicted_classes) > 0:
                class_idx = random.choice(predicted_classes)
                AU = config.AU_SQUEEZE[class_idx]
                print(AU)
                choice_activate_map = class_roi_activate_img[
                    class_idx]  # roi_h, roi_w
                activation_color_map = np.round(
                    cmap(choice_activate_map)[:, :, :3] * 255).astype(np.uint8)
                # blend ROI pixels and colour map half and half
                overlay_img = roi_images[
                    box_id] / 2 + activation_color_map.transpose(2, 0, 1) / 2
                overlay_img = np.transpose(overlay_img,
                                           (1, 2, 0)).astype(np.uint8)
                vis_img = cv2.cvtColor(overlay_img, cv2.COLOR_RGB2BGR)
                cv2.imshow("new", vis_img)
                cv2.waitKey(0)
コード例 #5
0
def read_DISFA_video_label(output_dir,
                           is_binary_AU,
                           is_need_adaptive_AU_relation=False,
                           force_generate=True,
                           proc_num=10,
                           cut=False,
                           train_subject=None):
    """Yield per-video face crops, AU masks and region labels for DISFA.

    For both camera orientations ("Left", "Right") and for every video
    directory under ``<DISFA>/ActionUnit_Labels/`` the frames are cropped
    (in a process pool when ``proc_num > 1``), the per-AU label files are
    parsed, and one record per usable frame is assembled.

    :param output_dir: only used when ``force_generate`` is false: a video is
        skipped if ``<output_dir>/<train|test>/<video>_<orientation>.npz``
        already exists
    :param is_binary_AU: True -> each region label is a tuple of 0/1 flags
        with ``len(config.AU_SQUEEZE)`` entries; False -> each region label
        is a comma-joined string of the AUs present, or ``"0"`` when none is
    :param is_need_adaptive_AU_relation: call ``adaptive_AU_relation()``
        before building the AU couple tables
    :param force_generate: when false, skip videos whose target file exists
    :param proc_num: number of worker processes; ``<= 1`` crops in-process
    :param cut: drop frames that have no AU with intensity >= 1
    :param train_subject: container of video names that belong to the
        training split; ``None`` is treated as empty
    :yield: ``(video_info, video_name)`` where ``video_info`` is a list of
        dicts with keys ``frame``, ``cropped_face``, ``all_couple_mask_dict``,
        ``all_labels`` and ``video_id``
    """
    # The default None used to raise TypeError on the membership test.
    train_subjects = () if train_subject is None else train_subject
    mgr = mp.Manager()
    queue = mgr.Queue(maxsize=20000)
    for orientation in ["Left", "Right"]:
        if is_need_adaptive_AU_relation:
            adaptive_AU_relation(
            )  # delete AU relation pair occur in same facial region
        au_couple_dict = get_zip_ROI_AU()
        au_couple_child_dict = get_AU_couple_child(au_couple_dict)
        DISFA_base_dir = config.DATA_PATH["DISFA"]
        label_file_dir = DISFA_base_dir + "/ActionUnit_Labels/"
        img_folder = DISFA_base_dir + "/Img_{}Camera".format(orientation)
        for video_name in os.listdir(label_file_dir):  # each file is a video
            video_label_dir = label_file_dir + os.sep + video_name
            if not force_generate:
                prefix = "train" if video_name in train_subjects else "test"
                target_file_path = output_dir + os.sep + prefix + os.sep + video_name + "_" + orientation + ".npz"
                if os.path.exists(target_file_path):
                    continue

            # The frame list is taken from the first label file of the video.
            one_file_name = os.listdir(video_label_dir)[0]
            frame_img_paths = []
            with open(video_label_dir + os.sep + one_file_name,
                      "r") as file_obj:
                for line in file_obj:
                    line = line.strip()
                    if line:
                        frame = line.split(",")[0]
                        frame_img_paths.append(
                            img_folder +
                            "/{0}/{1}.jpg".format(video_name, frame))

            resultdict = {}  # img_path => (cropped_face, AU_mask_dict)
            if proc_num > 1:
                pool = mp.Pool(processes=proc_num)
                try:
                    for img_path in frame_img_paths:
                        pool.apply_async(func=delegate_mask_crop,
                                         args=(img_path, True, queue))
                    # collect one queue entry per submitted frame
                    for _ in range(len(frame_img_paths)):
                        try:
                            entry = queue.get(block=True, timeout=60)
                            resultdict[entry[0]] = (entry[1], entry[2])
                        except Exception:
                            print("queue block time out")
                            break
                finally:
                    # always release the workers, even if something above
                    # raised
                    pool.close()
                    pool.join()
            else:  # only one process
                for img_path in frame_img_paths:
                    try:
                        cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
                            img_path, True)
                    except IndexError:
                        # the frame is skipped, but no longer silently
                        print("img_path:{} cannot obtain 68 landmark".format(
                            img_path))
                        continue
                    resultdict[img_path] = (cropped_face, AU_mask_dict)

            frame_label = dict()  # frame number => set of AU names (str)
            video_info = []
            video_img_path_set = set()
            for file_name in os.listdir(
                    video_label_dir):  # each file is one AU ( video file )
                AU = file_name[file_name.index("au") + 2:file_name.rindex(".")]

                with open(video_label_dir + os.sep + file_name,
                          "r") as file_obj:
                    for line in file_obj:
                        line = line.strip()
                        if not line:
                            # blank lines would make int() below fail
                            continue
                        fields = line.split(",")
                        frame = int(fields[0])
                        AU_intensity = int(fields[1])
                        img_path = img_folder + "/{0}/{1}.jpg".format(
                            video_name, frame)
                        video_img_path_set.add((frame, img_path))
                        if frame not in frame_label:
                            frame_label[frame] = set()
                        if AU_intensity >= 1:  # NOTE that we apply AU_intensity >= 1
                            frame_label[frame].add(AU)  # stored as str
            for frame, img_path in sorted(video_img_path_set,
                                          key=lambda e: int(e[0])):
                if img_path not in resultdict:
                    continue
                AU_set = frame_label[frame]  # it is whole image's AU set

                if cut and len(AU_set) == 0:
                    continue
                cropped_face, AU_mask_dict = resultdict[img_path]

                all_couple_mask_dict = OrderedDict()
                for AU in sorted(map(
                        int, config.AU_ROI.keys())):  # ensure same order
                    all_couple_mask_dict[au_couple_dict[str(
                        AU)]] = AU_mask_dict[str(AU)]

                all_labels = list()  # start assembling all_labels
                # same order as all_couple_mask_dict
                for AU_couple in all_couple_mask_dict.keys():
                    child_AU_couple_list = au_couple_child_dict[AU_couple]
                    AU_couple = set(AU_couple)
                    for child_AU_couple in child_AU_couple_list:
                        AU_couple.update(
                            child_AU_couple)  # combine child region's AU
                    if not is_binary_AU:
                        # CRF mode: the AUs of one region are joined with
                        # commas into a single label
                        # AU_set holds the ground-truth AUs as str
                        concat_AU = [AU for AU in AU_couple if AU in AU_set]

                        if len(concat_AU) == 0:
                            # no AU occurs in this region: use "0" so that a
                            # CRF supporting only single labels still works
                            all_labels.append("0")
                        else:
                            all_labels.append(",".join(sorted(concat_AU)))

                    else:  # convert to np.array which is AU_bin
                        AU_bin = np.zeros(len(config.AU_SQUEEZE)).astype(
                            np.uint8)
                        for AU in AU_couple:
                            if AU in AU_set:  # judge if this region contain which subset of whole image's AU_set
                                np.put(AU_bin, config.AU_SQUEEZE.inv[AU], 1)
                        all_labels.append(tuple(AU_bin))

                video_info.append({
                    "frame": frame,
                    "cropped_face": cropped_face,
                    "all_couple_mask_dict": all_couple_mask_dict,
                    "all_labels": all_labels,
                    "video_id": video_name + "_" + orientation
                })
            resultdict.clear()
            if video_info:
                yield video_info, video_name
            else:
                print(
                    "error in file:{} no video found".format(video_name + "_" +
                                                             orientation))
コード例 #6
0
def read_BP4D_video_label(output_dir,
                          is_binary_AU,
                          is_need_adaptive_AU_relation=False,
                          force_generate=True,
                          proc_num=10,
                          cut=False,
                          train_subject=None):
    '''
    Yield per-video face crops, AU masks and region labels for BP4D.

    Every file under ``<BP4D>/AUCoding/`` describes one video: the first line
    is a header naming the AU of each column, every following line is one
    frame.  Frames are cropped (in a process pool when ``proc_num > 1``) and
    one record per usable frame is assembled.

    :param output_dir: only used when ``force_generate`` is false: a video is
        skipped if ``<output_dir>/<train|test>/<subject>_<sequence>.npz``
        already exists
    :param is_binary_AU:
        True --> each region label is an AU binary vector such as 01010100
        False --> used for CRF mode: each region label is a single AU or a
        comma-separated AU combination ("?<AU>" marks label value 9,
        "0" means no AU in the region)
    :param is_need_adaptive_AU_relation: call ``adaptive_AU_relation()``
        before building the AU couple tables
    :param force_generate: when false, skip videos whose target file exists
    :param proc_num: number of worker processes; ``<= 1`` crops in-process
    :param cut: drop frames in which every AU label is 0
    :param train_subject: container of subject names that belong to the
        training split; ``None`` is treated as empty
    :yield: once per video, ``(video_info, subject_name)``; ``video_info`` is
        a list with one dict per frame holding
        1. "frame": zero-padded frame number
        2. "cropped_face": the cropped face image
        3. "all_couple_mask_dict": OrderedDict with the mask of every
           region, whether or not its AUs occur; keys are AU couples from
           ``get_zip_ROI_AU()``
        4. "all_labels": list in the same order as all_couple_mask_dict
        5. "video_id": "<subject>_<sequence>"
    '''
    # The default None used to raise TypeError on the membership test.
    train_subjects = () if train_subject is None else train_subject
    mgr = mp.Manager()
    queue = mgr.Queue(maxsize=20000)

    if is_need_adaptive_AU_relation:
        adaptive_AU_relation(
        )  # delete AU relation pair occur in same facial region
    au_couple_dict = get_zip_ROI_AU()
    au_couple_child_dict = get_AU_couple_child(
        au_couple_dict)  # AU_couple => list of child AU_couple
    BP4D_base_dir_path = config.DATA_PATH["BP4D"]
    label_file_dir = BP4D_base_dir_path + "/AUCoding/"

    for file_name in os.listdir(label_file_dir):  # each file is a video

        subject_name = file_name[:file_name.index("_")]
        sequence_name = file_name[file_name.index("_") +
                                  1:file_name.rindex(".")]
        if not force_generate:
            prefix = "train" if subject_name in train_subjects else "test"
            target_file_path = output_dir + os.sep + prefix + os.sep + subject_name + "_" + sequence_name + ".npz"
            if os.path.exists(target_file_path):
                continue

        label_file_path = label_file_dir + "/" + file_name
        sequence_img_dir = config.RGB_PATH[
            "BP4D"] + os.sep + subject_name + os.sep + sequence_name
        # Image file names are zero padded; take the width from any one file.
        one_image_path = os.listdir(sequence_img_dir)[0]
        zfill_len = len(one_image_path[:one_image_path.rindex(".")])

        # First pass: collect the frames whose image exists.
        frame_img_paths = []
        with open(label_file_path, "r") as au_file_obj:
            for idx, line in enumerate(au_file_obj):
                if idx == 0:
                    # header line; the single-process branch used to treat it
                    # as a frame
                    continue
                frame = line.split(",")[0].zfill(zfill_len)
                img_path = sequence_img_dir + os.sep + frame + ".jpg"
                if not os.path.exists(img_path):
                    print("not exists img_path:{}".format(img_path))
                    continue
                frame_img_paths.append(img_path)

        resultdict = {}  # img_path => (cropped_face, AU_mask_dict)
        if proc_num > 1:
            # read image file and crop and get AU mask
            pool = mp.Pool(processes=proc_num)
            try:
                for img_path in frame_img_paths:
                    pool.apply_async(func=delegate_mask_crop,
                                     args=(img_path, True, queue))
                # collect one queue entry per submitted frame
                for _ in range(len(frame_img_paths)):
                    try:
                        entry = queue.get(block=True, timeout=360)
                        resultdict[entry[0]] = (entry[1], entry[2])
                    except Exception:
                        print("queue block time out")
                        break
            finally:
                # always release the workers, even if something above raised
                pool.close()
                pool.join()
        else:  # only one process
            for img_path in frame_img_paths:
                try:
                    cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
                        img_path, channel_first=True)
                except IndexError:
                    print("img_path:{} cannot obtain 68 landmark".format(
                        img_path))
                    continue
                # note that above AU_mask_dict, each AU may have mask that contains multiple separate regions
                resultdict[img_path] = (cropped_face, AU_mask_dict)
                print("one image :{} done".format(img_path))

        # Second pass: read the labels and assemble one record per frame.
        AU_column_idx = {}
        video_info = []
        with open(label_file_path, "r") as au_file_obj:
            for idx, line in enumerate(
                    au_file_obj):  # each line represent a frame image

                line = line.rstrip()
                lines = line.split(",")
                if idx == 0:  # header define which column is which Action Unit
                    for col_idx, AU in enumerate(lines[1:]):
                        AU_column_idx[AU] = col_idx + 1  # read header
                    continue  # read head over , continue

                frame = lines[0].zfill(zfill_len)

                img_path = sequence_img_dir + os.sep + frame + ".jpg"
                if not os.path.exists(img_path):
                    print("not exists img_path:{}".format(img_path))
                    continue
                if img_path not in resultdict:
                    print("img_path:{} landmark not found, continue".format(
                        img_path))
                    continue
                cropped_face, AU_mask_dict = resultdict[img_path]

                all_couple_mask_dict = OrderedDict()
                for AU in sorted(map(
                        int, config.AU_ROI.keys())):  # ensure same order
                    all_couple_mask_dict[au_couple_dict[str(
                        AU)]] = AU_mask_dict[str(AU)]

                au_label_dict = {
                    AU: int(lines[AU_column_idx[AU]])
                    for AU in config.AU_ROI.keys()
                }  # store real AU label
                if cut and all(_au_label == 0
                               for _au_label in au_label_dict.values()):
                    continue
                all_labels = list()  # start assembling all_labels
                # same order as all_couple_mask_dict
                for AU_couple in all_couple_mask_dict.keys():
                    child_AU_couple_list = au_couple_child_dict[AU_couple]
                    AU_couple = set(AU_couple)
                    for child_AU_couple in child_AU_couple_list:
                        AU_couple.update(
                            child_AU_couple
                        )  # label fetch: combine child region's AU
                    if not is_binary_AU:
                        # CRF mode: the AUs of one region are joined with
                        # commas into a single label
                        concat_AU = []
                        # sorted: set iteration order is arbitrary, which
                        # made the joined label differ between runs
                        for AU in sorted(AU_couple):
                            if au_label_dict[AU] == 1:
                                concat_AU.append(AU)
                            elif au_label_dict[AU] == 9:
                                concat_AU.append("?{}".format(AU))

                        if len(concat_AU) == 0:
                            # no AU occurs in this region: use "0" so that a
                            # CRF supporting only single labels still works
                            all_labels.append("0")
                        else:
                            all_labels.append(",".join(concat_AU))

                    else:  # convert to np.array which is AU_bin
                        AU_bin = np.zeros(len(config.AU_SQUEEZE)).astype(
                            np.uint8)
                        for AU in AU_couple:
                            if au_label_dict[AU] == 9:
                                # NOTE(review): AU_bin is uint8, which cannot
                                # represent -1 -- confirm what consumers of
                                # all_labels expect for label value 9.
                                np.put(AU_bin, config.AU_SQUEEZE.inv[AU], -1)
                            elif au_label_dict[AU] == 1:
                                np.put(AU_bin, config.AU_SQUEEZE.inv[AU], 1)

                        all_labels.append(tuple(AU_bin))

                video_info.append({
                    "frame":
                    frame,
                    "cropped_face":
                    cropped_face,
                    "all_couple_mask_dict":
                    all_couple_mask_dict,
                    "all_labels":
                    all_labels,
                    "video_id":
                    subject_name + "_" + sequence_name
                })
        resultdict.clear()
        if video_info:
            yield video_info, subject_name
        else:
            print("error video_info:{}".format(file_name))
コード例 #7
0
def generate_mask_contain_img(database_name, img_path):
    """Render a face crop with its ROI polygons and tinted parent/child AU masks.

    The face in ``img_path`` is cropped, every ROI polygon is outlined and
    numbered on the crop, and then for each AU couple that is not a child of
    another couple its mask is blended in with the first colour of
    ``MASK_CONTAIN``; the masks of its child couples are blended on top with
    a randomly chosen child colour.

    :param database_name: passed to ``adaptive_AU_database``
    :param img_path: path of the image to crop
    :return: the blended image; all zeros if no parent couple was drawn
    """
    adaptive_AU_database(database_name)
    # NOTE(review): this BGR colour table is built but not read below.
    mask_color = {}
    for parent_rgb, child_rgb in MASK_CONTAIN.items():
        mask_color[color_bgr(parent_rgb)] = color_bgr(child_rgb)
    cropped_face, AU_mask_dict = FaceMaskCropper.get_cropface_and_mask(
        img_path, channel_first=False)

    AU_couple_dict = get_zip_ROI_AU()
    AU_couple_child = get_AU_couple_child(AU_couple_dict)
    land = FaceLandMark(config.DLIB_LANDMARK_PRETRAIN)
    landmark, _, _ = land.landmark(image=cropped_face)

    # outline and number every ROI polygon on the cropped face
    for roi_no, vertices in land.split_ROI(landmark).items():
        for row in (0, 1):
            vertices[row, :] = np.round(vertices[row, :])
        vertices = sort_clockwise(vertices.tolist())
        cv2.polylines(cropped_face, [vertices],
                      True,
                      color_bgr(RED),
                      thickness=1)
        label_pos = tuple(np.mean(vertices, axis=0).astype(np.int32))
        cv2.putText(cropped_face,
                    str(roi_no),
                    label_pos,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.7, (0, 255, 255),
                    thickness=1)

    # every couple that appears as a child of some other couple
    all_child_set = {
        child
        for child_set in AU_couple_child.values()
        for child in child_set
    }
    already_fill_AU = set()
    new_face = np.zeros_like(cropped_face)
    for AU in config.AU_ROI.keys():
        AU_couple = AU_couple_dict[AU]
        if AU_couple in all_child_set or AU_couple in already_fill_AU:
            continue
        already_fill_AU.add(AU_couple)

        parent_mask = AU_mask_dict[AU]
        child_AU_set = AU_couple_child[AU_couple]
        color_parent = list(MASK_CONTAIN.keys())[0]
        color_child = MASK_CONTAIN[color_parent]
        parent_overlay = cv2.cvtColor(parent_mask, cv2.COLOR_GRAY2RGB)
        parent_overlay[parent_mask != 0] = color_parent

        # once something has been drawn, keep blending onto the result
        if np.any(new_face):
            cropped_face = new_face
        cv2.addWeighted(cropped_face, 1, parent_overlay, 0.3, 0, new_face, -1)

        for child_AU in child_AU_set:
            if child_AU in already_fill_AU:
                continue
            already_fill_AU.add(child_AU)
            child_mask = AU_mask_dict[child_AU[0]]
            child_overlay = cv2.cvtColor(child_mask, cv2.COLOR_GRAY2RGB)
            child_overlay[child_mask != 0] = random.choice(color_child)
            cv2.addWeighted(new_face, 1, child_overlay, 0.5, 0, new_face, -1)

    return new_face