Example 1
def light_track(pose_estimator, image_folder, output_json_path,
                visualize_folder, output_video_path, gt_info):
    global total_time_POSE_ESTIMATOR, total_time_POSE_SIMILARITY, total_time_DET, total_time_ALL, total_time_ASSOCIATE
    global video_name, iou_alpha1, pose_alpha1
    global filter_bbox_number, total_num_FRAMES, total_num_PERSONS, total_num_VIDEOS
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()
    ### hyper-parameters
    keypoints_number = 15
    interval = 5

    bbox_dets_list_list = []
    keypoints_list_list = []

    num_imgs = len(gt_info)

    first_img_id = 0

    start_from_labeled = False
    if start_from_labeled:
        first_img_id = find_first_labeled_opensvai_json(gt_info)

    next_id = 0  # track_id counting starts from 0
    img_id = first_img_id
    total_num_FRAMES += num_imgs

    gt_frame_index_list = find_gt_frame_index_list(gt_info, interval=interval)
    while img_id < num_imgs:
        ## loop Initialization
        img_gt_info = gt_info[img_id]
        image_name, labeled, candidates_info = read_image_data_opensvai_json(
            img_gt_info)
        img_path = os.path.join(image_folder, image_name)

        bbox_dets_list = []  # keyframe: start from empty
        keypoints_list = []  # keyframe: start from empty
        prev_frame_img_id = max(0, img_id - first_img_id - 1)

        # If the first frame is a GT frame, copy the GT results directly into the list of lists
        if start_from_labeled and img_id == first_img_id:
            num_dets = len(candidates_info)
            for det_id in range(num_dets):
                track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(
                    candidates_info, det_id)
                # use the GT of the first frame directly
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": det_id,
                    "imgpath": img_path,
                    "track_id": track_id,
                    "bbox": bbox_det
                }
                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": det_id,
                    "imgpath": img_path,
                    "track_id": track_id,
                    "keypoints": keypoints
                }
                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
                next_id = max(next_id, track_id)
                next_id += 1
            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)
        else:
            #### Continuous tracking; when img_id lands on a GT frame, the GT is compared against the predictions.
            logger.info("Tracking,img_id:{}".format(img_id))
            candidates_total = []
            st_time_DET = time.time()
            candidates_from_detector = inference_yolov3(img_path)
            end_time_DET = time.time()
            total_time_DET += (end_time_DET - st_time_DET)

            candidates_from_prev = []

            bbox_list_prev_frame = []
            ''' Supplement candidate boxes from the previous frame's information '''
            if img_id > first_img_id:
                bbox_list_prev_frame = bbox_dets_list_list[
                    prev_frame_img_id].copy()
                keypoints_list_prev_frame = keypoints_list_list[
                    prev_frame_img_id].copy()
                num_prev_bbox = len(bbox_list_prev_frame)
                for prev_det_id in range(num_prev_bbox):
                    # obtain bbox position and track id
                    keypoints = keypoints_list_prev_frame[prev_det_id][
                        'keypoints']
                    bbox_det_next = get_bbox_from_keypoints(keypoints)
                    if bbox_invalid(bbox_det_next):
                        continue
                    # xywh
                    candidates_from_prev.append(bbox_det_next)
            ''' Gather all candidate boxes for this frame '''
            candidates_total = candidates_from_detector + candidates_from_prev
            num_candidate = len(candidates_total)
            ''' Use the keypoint confidences as the bbox confidence '''
            candidates_dets = []
            for candidate_id in range(num_candidate):
                bbox_det = candidates_total[candidate_id]
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": candidate_id,
                    "imgpath": img_path,
                    "track_id": None,
                    "bbox": bbox_det
                }
                st_time_pose = time.time()
                keypoints = inference_keypoints(pose_estimator,
                                                bbox_det_dict)[0]['keypoints']
                end_time_pose = time.time()
                total_time_POSE_ESTIMATOR += (end_time_pose - st_time_pose)
                bbox_det_next = xywh_to_x1y1x2y2(bbox_det)
                score = sum(keypoints[2::3]) / keypoints_number
                # Unclear why this pose confidence can exceed 1
                if bbox_invalid(bbox_det_next) or score < 0.7:
                    filter_bbox_number += 1
                    continue
                candidate_det = bbox_det_next + [score]
                candidates_dets.append(candidate_det)
                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": candidate_id,
                    "imgpath": img_path,
                    "track_id": None,
                    "keypoints": keypoints
                }

                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
            # Apply NMS based on the bbox confidences
            keep = py_cpu_nms(np.array(candidates_dets, dtype=np.float32),
                              0.5) if len(candidates_dets) > 0 else []

            candidates_total = np.array(candidates_total)[keep]
            t = bbox_dets_list.copy()
            k = keypoints_list.copy()
            # keep only the detections that survived NMS
            bbox_dets_list = [t[i] for i in keep]
            keypoints_list = [k[i] for i in keep]
            """ Data association """
            cur_det_number = len(candidates_total)
            prev_det_number = len(bbox_list_prev_frame)
            if img_id == first_img_id or prev_det_number == 0:
                for det_id, bbox_det_dict in enumerate(bbox_dets_list):
                    keypoints_dict = keypoints_list[det_id]
                    bbox_det_dict['det_id'] = det_id
                    keypoints_dict['det_id'] = det_id
                    track_id = next_id
                    bbox_det_dict['track_id'] = track_id
                    keypoints_dict['track_id'] = track_id
                    next_id = max(next_id, track_id)
                    next_id += 1
            else:
                scores = np.zeros((cur_det_number, prev_det_number))
                for det_id in range(cur_det_number):
                    bbox_det_dict = bbox_dets_list[det_id]
                    keypoints_dict = keypoints_list[det_id]
                    bbox_det = bbox_det_dict['bbox']
                    keypoints = keypoints_dict['keypoints']

                    # compute matching scores between the current bbox and the previous frame's bboxes
                    for prev_det_id in range(prev_det_number):
                        prev_bbox_det_dict = bbox_list_prev_frame[prev_det_id]
                        prev_keypoints_dict = keypoints_list_prev_frame[
                            prev_det_id]
                        iou_score = iou(bbox_det,
                                        prev_bbox_det_dict['bbox'],
                                        xyxy=False)
                        if iou_score > 0.5:
                            scores[det_id,
                                   prev_det_id] = iou_alpha1 * iou_score

                st_time_ass = time.time()
                bbox_dets_list, keypoints_list, now_next_id = bipartite_graph_matching(
                    bbox_dets_list, bbox_list_prev_frame, scores,
                    keypoints_list, next_id)
                end_time_ass = time.time()
                total_time_ASSOCIATE += (end_time_ass - st_time_ass)

                next_id = now_next_id

            if len(bbox_dets_list) == 0:
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "bbox": [0, 0, 2, 2]
                }
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "keypoints": []
                }
                keypoints_list.append(keypoints_dict)

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)
            ############################################################
            #### On GT frames, compare against the predicted results ###
            ############################################################
            if img_id in gt_frame_index_list and gt_frame_index_list.index(
                    img_id) >= 1:
                logger.info("type:{},img_id:{}".format('gt_guide', img_id))
                # gt frame
                num_dets = len(candidates_info)

                bbox_list_prediction = bbox_dets_list_list[
                    img_id - first_img_id].copy()
                keypoints_list_prediction = keypoints_list_list[
                    img_id - first_img_id].copy()
                bbox_list_gt = []
                keypoints_list_gt = []
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(
                        candidates_info, det_id)
                    bbox_det_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": track_id,
                        "bbox": bbox_det
                    }
                    keypoints_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": track_id,
                        "keypoints": keypoints
                    }

                    bbox_list_gt.append(bbox_det_dict)
                    keypoints_list_gt.append(keypoints_dict)
                bbox_dets_list_list[img_id - first_img_id] = bbox_list_gt
                keypoints_list_list[img_id - first_img_id] = keypoints_list_gt
                need_correct = distance_between_gt_prediction(
                    gt_dict={
                        "det": bbox_list_gt,
                        "keypoints": keypoints_list_gt
                    },
                    predict_dict={
                        "det": bbox_list_prediction,
                        "keypoints": keypoints_list_prediction
                    })
                if need_correct:
                    ## correct backwards from here
                    correct_index = img_id - 1
                    correct_end_index = img_id - int(interval / 2)
                    # walk from the newest frame back to the oldest
                    while correct_index >= correct_end_index:
                        ## assume the boxes are correct but the track ids are wrong
                        ## prev_det_number here refers to the GT detections
                        bbox_dets_list = bbox_dets_list_list[correct_index -
                                                             first_img_id]
                        keypoints_list = keypoints_list_list[correct_index -
                                                             first_img_id]

                        prev_det_number = len(bbox_list_gt)
                        cur_det_number = len(bbox_dets_list)
                        # prev is already matched; cur is still to be matched
                        scores = np.zeros((cur_det_number, prev_det_number))
                        for det_id in range(cur_det_number):
                            bbox_det_dict = bbox_dets_list[det_id]
                            keypoints_dict = keypoints_list[det_id]
                            bbox_det = bbox_det_dict['bbox']
                            keypoints = keypoints_dict['keypoints']

                            # compute matching scores between this frame's bboxes and the GT bboxes
                            for prev_det_id in range(prev_det_number):
                                bbox_det_dict_gt = bbox_list_gt[prev_det_id]
                                iou_score = iou(bbox_det,
                                                bbox_det_dict_gt['bbox'],
                                                xyxy=False)
                                if iou_score > 0.2:
                                    scores[
                                        det_id,
                                        prev_det_id] = iou_alpha1 * iou_score

                        if prev_det_number > 0 and cur_det_number > 0:
                            bbox_dets_list, keypoints_list, now_next_id = bipartite_graph_matching(
                                bbox_dets_list, bbox_list_gt, scores,
                                keypoints_list, next_id)

                        # no bbox survived in this frame
                        if len(bbox_dets_list) == 0:
                            bbox_det_dict = {
                                "img_id": img_id,
                                "det_id": 0,
                                "track_id": None,
                                "imgpath": img_path,
                                "bbox": [0, 0, 2, 2]
                            }
                            bbox_dets_list.append(bbox_det_dict)

                            keypoints_dict = {
                                "img_id": img_id,
                                "det_id": 0,
                                "track_id": None,
                                "imgpath": img_path,
                                "keypoints": []
                            }
                            keypoints_list.append(keypoints_dict)
                        bbox_dets_list_list[
                            correct_index -
                            first_img_id] = bbox_dets_list.copy()
                        keypoints_list_list[
                            correct_index -
                            first_img_id] = keypoints_list.copy()
                        correct_index -= 1

        img_id += 1
    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list,
                                          bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path,
                                               classes,
                                               joint_pairs,
                                               joint_names,
                                               image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")

        fps = 5  # originally 25
        visualizer.make_video_from_images(img_paths,
                                          output_video_path,
                                          fps=fps,
                                          size=None,
                                          is_color=True,
                                          format="XVID")
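
Example 1 gates every association on iou(bbox_det, prev_bbox_det_dict['bbox'], xyxy=False), keeping pairs above 0.5 (0.2 in the backward-correction pass). The iou() helper is not part of the listing; a minimal sketch, assuming xyxy=False means the boxes arrive as [x, y, w, h]:

def iou(boxA, boxB, xyxy=False):
    """Intersection-over-union of two boxes. With xyxy=False the inputs are
    assumed to be [x, y, w, h] and are converted first (assumption)."""
    if not xyxy:
        boxA = [boxA[0], boxA[1], boxA[0] + boxA[2], boxA[1] + boxA[3]]
        boxB = [boxB[0], boxB[1], boxB[0] + boxB[2], boxB[1] + boxB[3]]
    # intersection rectangle
    ix1, iy1 = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    ix2, iy2 = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
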
Example 2
def light_track(pose_estimator,
                image_folder, output_json_path,
                visualize_folder, output_video_path):
    global total_time_POSE, total_time_DET, total_time_ALL, total_num_FRAMES, total_num_PERSONS
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()

    # process the frames sequentially
    keypoints_list = []
    bbox_dets_list = []
    # frame_prev = -1
    # frame_cur = 0
    img_id = -1
    next_id = 0
    bbox_dets_list_list = []
    keypoints_list_list = []

    flag_mandatory_keyframe = False

    img_paths = get_immediate_childfile_paths(image_folder)
    num_imgs = len(img_paths)
    total_num_FRAMES = num_imgs

    # frames that have GT bboxes
    gt_bbox_img_id_list = [0]

    while img_id < num_imgs - 1:
        img_id += 1
        img_path = img_paths[img_id]
        print("Current tracking: [image_id:{}]".format(img_id))
        frame_cur = img_id

        bbox_dets_list = []  # keyframe: start from empty
        keypoints_list = []  # keyframe: start from empty

        if img_id in gt_bbox_img_id_list:
            # the current frame is a GT frame
            # once the data is prepared this should use real GT; for now it is pseudo-GT
            ##  TODO remove once the data is ready
            human_candidates = inference_yolov3(img_path)  # get bboxes
            num_dets = len(human_candidates)
            # detect keypoints for each bbox
            for det_id in range(num_dets):
                bbox_det = human_candidates[det_id]
                bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, enlarge_scale)
                bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)
                # update current frame bbox
                bbox_det_dict = {"img_id": img_id,
                                 "det_id": det_id,
                                 "imgpath": img_path,
                                 "track_id": det_id,
                                 "bbox": bbox_det}

                # keypoint detection, with timing
                st_time_pose = time.time()
                keypoints = inference_keypoints(pose_estimator, bbox_det_dict)[0]["keypoints"]
                end_time_pose = time.time()
                total_time_POSE += (end_time_pose - st_time_pose)

                keypoints_dict = {"img_id": img_id,
                                  "det_id": det_id,
                                  "imgpath": img_path,
                                  "track_id": det_id,
                                  "keypoints": keypoints}
                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
            # assert len(bbox_dets_list) == 2
            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)
        else:
            # the current frame is not a GT frame
            # perform detection at keyframes
            st_time_detection = time.time()
            # human_candidates: (center_x, center_y, w, h)
            human_candidates, confidence_scores = inference_yolov3_v1(img_path)  # get bboxes

            end_time_detection = time.time()
            total_time_DET += (end_time_detection - st_time_detection)

            num_dets = len(human_candidates)
            print("Keyframe: {} detections".format(num_dets))

            # if nothing detected at this frame
            if num_dets <= 0:
                ## TODO
                break

            # detect keypoints for each bbox
            for det_id in range(num_dets):
                bbox_det = human_candidates[det_id]
                bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, enlarge_scale)
                bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)
                # update current frame bbox
                bbox_det_dict = {"img_id": img_id,
                                 "det_id": det_id,
                                 "imgpath": img_path,
                                 "track_id": None,
                                 "bbox": bbox_det}

                # keypoint detection, with timing
                st_time_pose = time.time()
                keypoints = inference_keypoints(pose_estimator, bbox_det_dict)[0]["keypoints"]
                end_time_pose = time.time()
                total_time_POSE += (end_time_pose - st_time_pose)

                keypoints_dict = {"img_id": img_id,
                                  "det_id": det_id,
                                  "imgpath": img_path,
                                  "track_id": None,
                                  "keypoints": keypoints}
                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)

            # fetch the previous frame's information
            bbox_list_prev_frame = bbox_dets_list_list[img_id - 1].copy()
            keypoints_list_prev_frame = keypoints_list_list[img_id - 1].copy()

            ############ Cropping
            if img_id in [34, 35, 36, 37, 38]:
                cnt = 0
                for bbox_info in bbox_list_prev_frame:
                    bbox_det = bbox_info['bbox']
                    image_path = bbox_info['imgpath']
                    frame_name = os.path.basename(image_path)
                    frame_name = frame_name.split('.')[0]
                    video_name = os.path.basename(image_folder)
                    image = cv2.imread(image_path)
                    bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                    bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, 0.1)
                    bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)
                    x1, y1, w, h = max(int(bbox_det[0]), 0), max(int(bbox_det[1]), 0), bbox_det[2], bbox_det[3]
                    ### get the cropped image
                    cropped_image = image[y1:(y1 + h), x1:(x1 + w)]
                    create_folder(os.path.join(image_crop_output_path, video_name))
                    cropped_image_path = os.path.join(image_crop_output_path, video_name,
                                                      '{}-{:0>3d}.jpg'.format(frame_name, cnt))
                    cv2.imwrite(cropped_image_path, cropped_image)
                    ### detect bboxes on the crop
                    crop_human_candidates, _ = inference_yolov3_v1(cropped_image_path)
                    for det_id in range(len(crop_human_candidates)):
                        bbox_det = crop_human_candidates[det_id]
                        ### draw the bbox
                        # cropped_bbox_image = visualizer.draw_bbox_from_python_data(cropped_image, bbox_det)
                        cropped_bbox_image = cv2.rectangle(cropped_image.copy(), (int(bbox_det[0]), int(bbox_det[1])),
                                                           (int(bbox_det[0] + bbox_det[2]),
                                                            int(bbox_det[1] + bbox_det[3])),
                                                           (255, 0, 255), thickness=3)
                        cropped_image_bbox_path = os.path.join(image_crop_output_path, video_name,
                                                               '{}-{:0>3d}-{:0>3d}.jpg'.format(frame_name, cnt, det_id))
                        cv2.imwrite(cropped_image_bbox_path, cropped_bbox_image)
                    cnt += 1

            ##############

            num_bbox_prev_frame = len(bbox_list_prev_frame)

            # gather the three matching metrics (detector confidence, pose matching, IoU)
            confidence_scores = np.array(confidence_scores)
            confidence_scores = confidence_scores[:, np.newaxis]
            pose_matching_scores = np.zeros([num_dets, num_bbox_prev_frame], dtype=float)
            iou_scores = np.ones([num_dets, num_bbox_prev_frame], dtype=float)
            prev_track_ids = []
            for bbox_prev_index in range(num_bbox_prev_frame):
                # track ids present in the previous frame
                track_id = keypoints_list_prev_frame[bbox_prev_index]["track_id"]
                prev_track_ids.append(track_id)
            for det_id in range(num_dets):
                for bbox_prev_index in range(num_bbox_prev_frame):
                    keypoints_cur_frame = keypoints_list[det_id]["keypoints"]
                    bbox_cur_frame = bbox_dets_list[det_id]["bbox"]

                    keypoints_prev_frame = keypoints_list_prev_frame[bbox_prev_index]["keypoints"]
                    bbox_prev_frame = bbox_list_prev_frame[bbox_prev_index]["bbox"]
                    # get pose match score
                    pose_matching_scores[det_id, bbox_prev_index] = get_pose_matching_score(keypoints_cur_frame,
                                                                                            keypoints_prev_frame,
                                                                                            bbox_cur_frame,
                                                                                            bbox_prev_frame)

                    # get bbox distance score
                    iou_scores[det_id, bbox_prev_index] = iou(bbox_cur_frame, bbox_prev_frame, xyxy=False)

            ######################################################
            ## select the current frame's boxes by these metrics ##
            ######################################################
            bbox_dets_list, keypoints_list = select_bbox_by_criterion(bbox_dets_list, keypoints_list, confidence_scores,
                                                                      pose_matching_scores, iou_scores, prev_track_ids)
            print("Final save bbox number: {} ".format(len(bbox_dets_list)))
            print("image path:{}".format(img_path))
            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

            # ##############################
            # ## update bbox information ##
            # ##############################
            # temp_bbox_dets_list = []  ## temporary storage
            # temp_keypoints_list = []  ## temporary storage
            # for index, bbox_save_index in enumerate(bbox_save_index_list):
            #     if bbox_save_index == None:
            #         continue
            #     bbox_det_dict['track_id'] = index
            #     bbox_det_dict = bbox_dets_list[bbox_save_index]
            #     keypoints_dict = keypoints_list[bbox_save_index]
            #     bbox_det_dict['det_id'] = index
            #     keypoints_dict['track_id'] = index
            #     keypoints_dict['det_id'] = index
            #
            #     temp_bbox_dets_list.append(bbox_det_dict)
            #     temp_keypoints_list.append(keypoints_dict)

    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list, bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path, classes, joint_pairs, joint_names, image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")
        visualizer.make_video_from_images(img_paths, output_video_path, fps=25, size=None, is_color=True,
                                          format="XVID")
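
The function above round-trips every detection through xywh_to_x1y1x2y2, enlarge_bbox, and x1y1x2y2_to_xywh before pose estimation. Those helpers are not shown; a plausible sketch, assuming xywh means [top-left x, top-left y, width, height] (as the cropping code image[y1:(y1 + h), x1:(x1 + w)] suggests) and that enlarge_bbox pads each side proportionally:

def xywh_to_x1y1x2y2(box):
    # [x, y, w, h] -> [x1, y1, x2, y2]
    x, y, w, h = box
    return [x, y, x + w, y + h]

def x1y1x2y2_to_xywh(box):
    # [x1, y1, x2, y2] -> [x, y, w, h]
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]

def enlarge_bbox(box, scale):
    # pad an [x1, y1, x2, y2] box by `scale` times its size on each side (assumption)
    x1, y1, x2, y2 = box
    pad_w, pad_h = scale * (x2 - x1), scale * (y2 - y1)
    return [x1 - pad_w, y1 - pad_h, x2 + pad_w, y2 + pad_h]
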
def light_track(pose_estimator, image_folder, output_json_path,
                visualize_folder, output_video_path, gt_info):
    global total_time_POSE, total_time_DET, total_time_ALL, total_num_FRAMES, total_num_PERSONS
    global video_name
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()

    next_id = 1
    bbox_dets_list_list = []
    keypoints_list_list = []

    num_imgs = len(gt_info)
    total_num_FRAMES = num_imgs
    first_img_id = 0

    start_from_labeled = False
    if start_from_labeled:
        first_img_id = find_first_labeled_opensvai_json(gt_info)

    img_id = first_img_id
    """  之后的数据关联如何做?

    TODO  
    相邻两帧之间的检测框匹配,权值使用 total score = IOU score + Pose similarity, 之后再使用匈牙利算法进行二分图的匹配。

    大致思路:
        1.先算出距离相似度和pose相似度,将两者相加作为 框之间的联系。
        2.过滤低置信度的框。
        3.二分图匹配。
    """
    iou_alpha1 = 1
    pose_alpha1 = -1.3  # 求的是pose差异值,差异值越小表示越越相似。
    while img_id < num_imgs:

        img_gt_info = gt_info[img_id]
        image_name, labeled, candidates_info = read_image_data_opensvai_json(
            img_gt_info)
        img_path = os.path.join(image_folder, image_name)

        bbox_dets_list = []  # keyframe: start from empty
        keypoints_list = []  # keyframe: start from empty
        prev_frame_img_id = max(0, img_id - first_img_id - 1)
        if labeled and (img_id - first_img_id) % 5 == 0:
            logger.info("type:{},img_id:{}".format('gt', img_id))
            # gt frame

            num_dets = len(candidates_info)

            if img_id == first_img_id:
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(
                        candidates_info, det_id)
                    # use the GT of the first frame directly
                    bbox_det_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": track_id,
                        "bbox": bbox_det
                    }
                    keypoints_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": track_id,
                        "keypoints": keypoints
                    }
                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)
                    next_id = max(next_id, track_id)
                    next_id += 1
            else:  # Not First Frame
                bbox_list_prev_frame = bbox_dets_list_list[
                    prev_frame_img_id].copy()
                keypoints_list_prev_frame = keypoints_list_list[
                    prev_frame_img_id].copy()
                scores = np.zeros((num_dets, len(keypoints_list_prev_frame)))  # zero-init; np.empty would leave garbage in unmatched cells
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(
                        candidates_info, det_id)
                    bbox_det_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": None,
                        "bbox": bbox_det
                    }
                    keypoints_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": None,
                        "keypoints": keypoints
                    }
                    # compute matching scores between the current bbox and the previous frame's bboxes
                    for prev_det_id in range(len(keypoints_list_prev_frame)):
                        prev_bbox_det_dict = bbox_list_prev_frame[prev_det_id]
                        prev_keypoints_dict = keypoints_list_prev_frame[
                            prev_det_id]
                        iou_score = iou(bbox_det,
                                        prev_bbox_det_dict['bbox'],
                                        xyxy=False)
                        if iou_score > 0.5:
                            pose_match_score = get_pose_matching_score(
                                keypoints, prev_keypoints_dict["keypoints"],
                                bbox_det_dict['bbox'],
                                prev_bbox_det_dict['bbox'])
                            scores[
                                det_id,
                                prev_det_id] = iou_alpha1 * iou_score + pose_alpha1 * pose_match_score

                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)

                bbox_dets_list, keypoints_list, now_next_id = bipartite_graph_matching(
                    bbox_dets_list, keypoints_list_prev_frame, scores,
                    keypoints_list, next_id)
                next_id = now_next_id + 1

            # no bbox survived in this frame
            if len(bbox_dets_list) == 0:
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "bbox": [0, 0, 2, 2]
                }
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "keypoints": []
                }
                keypoints_list.append(keypoints_dict)

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

        else:
            logger.info("type:{},img_id:{}".format('normal', img_id))
            ''' NOT GT Frame '''
            candidates_total = []
            candidates_from_detector = inference_yolov3(img_path)

            candidates_from_prev = []

            bbox_list_prev_frame = []
            ''' Supplement candidate boxes from the previous frame's information '''
            if img_id > first_img_id:
                bbox_list_prev_frame = bbox_dets_list_list[
                    prev_frame_img_id].copy()
                keypoints_list_prev_frame = keypoints_list_list[
                    prev_frame_img_id].copy()
                num_prev_bbox = len(bbox_list_prev_frame)
                for prev_det_id in range(num_prev_bbox):
                    # obtain bbox position and track id
                    keypoints = keypoints_list_prev_frame[prev_det_id][
                        'keypoints']
                    bbox_det_next = get_bbox_from_keypoints(keypoints)
                    if bbox_invalid(bbox_det_next):
                        continue
                    # xywh
                    candidates_from_prev.append(bbox_det_next)
            ''' Gather all candidate boxes for this frame '''
            candidates_total = candidates_from_detector + candidates_from_prev
            num_candidate = len(candidates_total)
            ''' Use the keypoint confidences as the bbox confidence '''
            candidates_dets = []
            for candidate_id in range(num_candidate):
                bbox_det = candidates_total[candidate_id]
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": candidate_id,
                    "imgpath": img_path,
                    "track_id": None,
                    "bbox": bbox_det
                }
                keypoints = inference_keypoints(pose_estimator,
                                                bbox_det_dict)[0]['keypoints']

                bbox_det_next = xywh_to_x1y1x2y2(bbox_det)
                score = sum(keypoints[2::3]) / 25
                if bbox_invalid(bbox_det_next) or score < 0.33:
                    continue
                candidate_det = bbox_det_next + [score]
                candidates_dets.append(candidate_det)
                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": candidate_id,
                    "imgpath": img_path,
                    "track_id": None,
                    "keypoints": keypoints
                }

                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
            # Apply NMS based on the bbox confidences
            keep = py_cpu_nms(np.array(candidates_dets, dtype=np.float32),
                              0.5) if len(candidates_dets) > 0 else []

            candidates_total = np.array(candidates_total)[keep]
            t = bbox_dets_list.copy()
            k = keypoints_list.copy()
            # keep only the detections that survived NMS
            bbox_dets_list = [t[i] for i in keep]
            keypoints_list = [k[i] for i in keep]
            """ Data association """
            cur_det_number = len(candidates_total)
            prev_det_number = len(bbox_list_prev_frame)
            if img_id == first_img_id or prev_det_number == 0:
                for det_id, bbox_det_dict in enumerate(bbox_dets_list):
                    keypoints_dict = keypoints_list[det_id]
                    bbox_det_dict['det_id'] = det_id
                    keypoints_dict['det_id'] = det_id
                    track_id = next_id
                    bbox_det_dict['track_id'] = track_id
                    keypoints_dict['track_id'] = track_id
                    next_id = max(next_id, track_id)
                    next_id += 1
            else:
                scores = np.zeros((cur_det_number, prev_det_number))
                for det_id in range(cur_det_number):
                    bbox_det_dict = bbox_dets_list[det_id]
                    keypoints_dict = keypoints_list[det_id]
                    bbox_det = bbox_det_dict['bbox']
                    keypoints = keypoints_dict['keypoints']

                    # compute matching scores between the current bbox and the previous frame's bboxes
                    for prev_det_id in range(prev_det_number):
                        prev_bbox_det_dict = bbox_list_prev_frame[prev_det_id]
                        prev_keypoints_dict = keypoints_list_prev_frame[
                            prev_det_id]
                        iou_score = iou(bbox_det,
                                        prev_bbox_det_dict['bbox'],
                                        xyxy=False)
                        if iou_score > 0.5:
                            pose_match_score = get_pose_matching_score(
                                keypoints, prev_keypoints_dict["keypoints"],
                                bbox_det_dict["bbox"],
                                prev_bbox_det_dict["bbox"])
                            scores[
                                det_id,
                                prev_det_id] = iou_alpha1 * iou_score + pose_alpha1 * pose_match_score

                bbox_dets_list, keypoints_list, now_next_id = bipartite_graph_matching(
                    bbox_dets_list, bbox_list_prev_frame, scores,
                    keypoints_list, next_id)
                next_id = now_next_id + 1

            if len(bbox_dets_list) == 0:
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "bbox": [0, 0, 2, 2]
                }
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "keypoints": []
                }
                keypoints_list.append(keypoints_dict)

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

        img_id += 1
    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list,
                                          bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path,
                                               classes,
                                               joint_pairs,
                                               joint_names,
                                               image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")

        fps = 5  # originally 25
        visualizer.make_video_from_images(img_paths,
                                          output_video_path,
                                          fps=fps,
                                          size=None,
                                          is_color=True,
                                          format="XVID")
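
The docstring in the function above lays out the association scheme: weight each (current, previous) detection pair by total score = IOU score + pose similarity, then solve the assignment with the Hungarian algorithm. bipartite_graph_matching itself is not shown; a minimal sketch of that step using scipy.optimize.linear_sum_assignment, simplified to the bbox dicts only (the real helper also carries keypoints_list along and returns it updated):

import numpy as np
from scipy.optimize import linear_sum_assignment

def bipartite_graph_matching(cur_dets, prev_dets, scores, next_id, min_score=1e-6):
    # Hypothetical sketch: Hungarian assignment maximizing the total matching score.
    scores = np.asarray(scores)
    if len(cur_dets) > 0 and len(prev_dets) > 0:
        rows, cols = linear_sum_assignment(-scores)  # negate to maximize
    else:
        rows, cols = [], []
    # keep only assignments with a positive score (the IoU/pose gates leave zeros)
    matched = {r: c for r, c in zip(rows, cols) if scores[r, c] > min_score}
    for det_id, det in enumerate(cur_dets):
        if det_id in matched:
            # inherit the track id of the matched previous detection
            det['track_id'] = prev_dets[matched[det_id]]['track_id']
        else:
            # an unmatched detection starts a new track
            det['track_id'] = next_id
            next_id += 1
        det['det_id'] = det_id
    return cur_dets, next_id
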
def light_track(pose_estimator,
                image_folder, output_json_path,
                visualize_folder, output_video_path, gt_info):
    global total_time_POSE_ESTIMATOR, total_time_POSE_SIMILARITY, total_time_DET, total_time_ALL, total_time_ASSOCIATE
    global video_name, iou_alpha1, pose_alpha1
    global filter_bbox_number, total_num_FRAMES, total_num_PERSONS, total_num_VIDEOS
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()

    bbox_dets_list_list = []
    keypoints_list_list = []

    num_imgs = len(gt_info)

    first_img_id = 0

    start_from_labeled = False
    if start_from_labeled:
        first_img_id = find_first_labeled_opensvai_json(gt_info)

    # last_gt_img_id = find_last_labeled_opensvai_json(gt_info)
    # num_imgs = last_gt_img_id + 1
    next_id = 0  # track_id counting starts from 0
    img_id = first_img_id
    keypoints_number = 15
    total_num_FRAMES = num_imgs

    while img_id < num_imgs:

        img_gt_info = gt_info[img_id]
        image_name, labeled, candidates_info = read_image_data_opensvai_json(img_gt_info)
        img_path = os.path.join(image_folder, image_name)

        bbox_dets_list = []  # keyframe: start from empty
        keypoints_list = []  # keyframe: start from empty
        prev_frame_img_id = max(0, img_id - first_img_id - 1)
        if labeled and (img_id - first_img_id) % 5 == 0:
            logger.info("type:{},img_id:{}".format('gt', img_id))
            # gt frame

            num_dets = len(candidates_info)

            if img_id == first_img_id:
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(candidates_info, det_id)
                    # use the GT of the first frame directly
                    bbox_det_dict = {"img_id": img_id,
                                     "det_id": det_id,
                                     "imgpath": img_path,
                                     "track_id": track_id,
                                     "bbox": bbox_det}
                    keypoints_dict = {"img_id": img_id,
                                      "det_id": det_id,
                                      "imgpath": img_path,
                                      "track_id": track_id,
                                      "keypoints": keypoints}
                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)
                    next_id = max(next_id, track_id)
                    next_id += 1
            else:  # Not First Frame
                bbox_list_prev_frame = bbox_dets_list_list[prev_frame_img_id].copy()
                keypoints_list_prev_frame = keypoints_list_list[prev_frame_img_id].copy()
                scores = np.zeros((num_dets, len(keypoints_list_prev_frame)))
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(candidates_info, det_id)
                    bbox_det_dict = {"img_id": img_id,
                                     "det_id": det_id,
                                     "imgpath": img_path,
                                     "track_id": None,
                                     "bbox": bbox_det}
                    keypoints_dict = {"img_id": img_id,
                                      "det_id": det_id,
                                      "imgpath": img_path,
                                      "track_id": None,
                                      "keypoints": keypoints}
                    # compute matching scores between the current bbox and the previous frame's bboxes
                    for prev_det_id in range(len(keypoints_list_prev_frame)):
                        prev_bbox_det_dict = bbox_list_prev_frame[prev_det_id]
                        prev_keypoints_dict = keypoints_list_prev_frame[prev_det_id]
                        iou_score = iou(bbox_det, prev_bbox_det_dict['bbox'], xyxy=False)
                        if iou_score > 0.5:
                            st_time_pose = time.time()
                            # the GT keypoints are not fully annotated; unlabeled joints have confidence c == 0
                            prev_keypoints = prev_keypoints_dict["keypoints"].copy()
                            for index, value in enumerate(keypoints[2::3]):
                                if value == 0:
                                    prev_keypoints[index * 3:(index + 1) * 3] = 0, 0, 0

                            pose_match_score = get_pose_matching_score(keypoints, prev_keypoints,
                                                                       bbox_det_dict['bbox'],
                                                                       prev_bbox_det_dict['bbox'])
                            end_time_pose = time.time()
                            total_time_POSE_SIMILARITY += (end_time_pose - st_time_pose)
                            scores[det_id, prev_det_id] = iou_alpha1 * iou_score + pose_alpha1 * pose_match_score

                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)
                st_time_ass = time.time()
                bbox_dets_list, keypoints_list, now_next_id = bipartite_graph_matching(bbox_dets_list,
                                                                                       keypoints_list_prev_frame,
                                                                                       scores, keypoints_list, next_id)
                end_time_ass = time.time()
                total_time_ASSOCIATE += (end_time_ass - st_time_ass)

                next_id = now_next_id

            # no bbox survived in this frame
            if len(bbox_dets_list) == 0:
                bbox_det_dict = {"img_id": img_id,
                                 "det_id": 0,
                                 "track_id": None,
                                 "imgpath": img_path,
                                 "bbox": [0, 0, 2, 2]}
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {"img_id": img_id,
                                  "det_id": 0,
                                  "track_id": None,
                                  "imgpath": img_path,
                                  "keypoints": []}
                keypoints_list.append(keypoints_dict)

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

        else:
            logger.info("type:{},img_id:{}".format('normal', img_id))
            ''' NOT GT Frame '''
            candidates_total = []
            st_time_DET = time.time()
            candidates_from_detector = inference_yolov3(img_path)
            end_time_DET = time.time()
            total_time_DET += (end_time_DET - st_time_DET)

            candidates_from_prev = []

            bbox_list_prev_frame = []
            ''' Supplement candidate boxes from the previous frame's information '''
            if img_id > first_img_id:
                bbox_list_prev_frame = bbox_dets_list_list[prev_frame_img_id].copy()
                keypoints_list_prev_frame = keypoints_list_list[prev_frame_img_id].copy()
                num_prev_bbox = len(bbox_list_prev_frame)
                for prev_det_id in range(num_prev_bbox):
                    # obtain bbox position and track id
                    keypoints = keypoints_list_prev_frame[prev_det_id]['keypoints']
                    bbox_det_next = get_bbox_from_keypoints(keypoints)
                    if bbox_invalid(bbox_det_next):
                        continue
                    # xywh
                    candidates_from_prev.append(bbox_det_next)

            ''' Gather all candidate boxes for this frame '''
            candidates_total = candidates_from_detector + candidates_from_prev
            num_candidate = len(candidates_total)
            ''' Use the keypoint confidences as the bbox confidence '''
            candidates_dets = []
            for candidate_id in range(num_candidate):
                bbox_det = candidates_total[candidate_id]
                bbox_det_dict = {"img_id": img_id,
                                 "det_id": candidate_id,
                                 "imgpath": img_path,
                                 "track_id": None,
                                 "bbox": bbox_det}
                st_time_pose = time.time()
                keypoints = inference_keypoints(pose_estimator, bbox_det_dict)[0]['keypoints']
                end_time_pose = time.time()
                total_time_POSE_ESTIMATOR += (end_time_pose - st_time_pose)
                bbox_det_next = xywh_to_x1y1x2y2(bbox_det)
                score = sum(keypoints[2::3]) / keypoints_number
                # Unclear why this pose confidence can exceed 1
                if bbox_invalid(bbox_det_next) or score < 0.7:
                    filter_bbox_number += 1
                    continue
                candidate_det = bbox_det_next + [score]
                candidates_dets.append(candidate_det)
                keypoints_dict = {"img_id": img_id,
                                  "det_id": candidate_id,
                                  "imgpath": img_path,
                                  "track_id": None,
                                  "keypoints": keypoints}

                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
            # Apply NMS based on the bbox confidences
            keep = py_cpu_nms(np.array(candidates_dets, dtype=np.float32), 0.5) if len(candidates_dets) > 0 else []

            candidates_total = np.array(candidates_total)[keep]
            t = bbox_dets_list.copy()
            k = keypoints_list.copy()
            # keep only the detections that survived NMS
            bbox_dets_list = [t[i] for i in keep]
            keypoints_list = [k[i] for i in keep]
            """ Data association """
            cur_det_number = len(candidates_total)
            prev_det_number = len(bbox_list_prev_frame)
            if img_id == first_img_id or prev_det_number == 0:
                for det_id, bbox_det_dict in enumerate(bbox_dets_list):
                    keypoints_dict = keypoints_list[det_id]
                    bbox_det_dict['det_id'] = det_id
                    keypoints_dict['det_id'] = det_id
                    track_id = next_id
                    bbox_det_dict['track_id'] = track_id
                    keypoints_dict['track_id'] = track_id
                    next_id = max(next_id, track_id)
                    next_id += 1
            else:
                scores = np.zeros((cur_det_number, prev_det_number))
                for det_id in range(cur_det_number):
                    bbox_det_dict = bbox_dets_list[det_id]
                    keypoints_dict = keypoints_list[det_id]
                    bbox_det = bbox_det_dict['bbox']
                    keypoints = keypoints_dict['keypoints']

                    # compute matching scores between the current bbox and the previous frame's bboxes
                    for prev_det_id in range(prev_det_number):
                        prev_bbox_det_dict = bbox_list_prev_frame[prev_det_id]
                        prev_keypoints_dict = keypoints_list_prev_frame[prev_det_id]
                        iou_score = iou(bbox_det, prev_bbox_det_dict['bbox'], xyxy=False)
                        if iou_score > 0.5:
                            st_time_pose = time.time()
                            pose_match_score = get_pose_matching_score(keypoints, prev_keypoints_dict["keypoints"],
                                                                       bbox_det_dict["bbox"],
                                                                       prev_bbox_det_dict["bbox"])
                            end_time_pose = time.time()
                            total_time_POSE_SIMILARITY += (end_time_pose - st_time_pose)
                            scores[det_id, prev_det_id] = iou_alpha1 * iou_score + pose_alpha1 * pose_match_score

                st_time_ass = time.time()
                bbox_dets_list, keypoints_list, now_next_id = bipartite_graph_matching(bbox_dets_list,
                                                                                       bbox_list_prev_frame, scores,
                                                                                       keypoints_list, next_id)
                end_time_ass = time.time()
                total_time_ASSOCIATE += (end_time_ass - st_time_ass)

                next_id = now_next_id

            if len(bbox_dets_list) == 0:
                bbox_det_dict = {"img_id": img_id,
                                 "det_id": 0,
                                 "track_id": None,
                                 "imgpath": img_path,
                                 "bbox": [0, 0, 2, 2]}
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {"img_id": img_id,
                                  "det_id": 0,
                                  "track_id": None,
                                  "imgpath": img_path,
                                  "keypoints": []}
                keypoints_list.append(keypoints_dict)

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

        img_id += 1

    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list, bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path, classes, joint_pairs, joint_names,
                                               image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")

        fps = 5  # originally 25
        visualizer.make_video_from_images(img_paths, output_video_path, fps=fps, size=None, is_color=True,
                                          format="XVID")
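
On non-GT frames the code above scores each candidate by its mean keypoint confidence (sum(keypoints[2::3]) / keypoints_number takes every third element of the flat [x, y, c, ...] keypoint list) and prunes overlaps with py_cpu_nms on [x1, y1, x2, y2, score] rows. py_cpu_nms is the classic CPU NMS from py-faster-rcnn; a compact equivalent:

import numpy as np

def py_cpu_nms(dets, thresh):
    """Keep the highest-scoring boxes, suppressing any box whose IoU with an
    already kept box exceeds thresh. dets: (N, 5) array [x1, y1, x2, y2, score]."""
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # indices sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # IoU of the current top box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[np.where(ovr <= thresh)[0] + 1]
    return keep
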
Example 5
def light_track(pose_estimator,
                image_folder, output_json_path,
                visualize_folder, output_video_path):
    global total_time_POSE, total_time_DET, total_time_ALL, total_num_FRAMES, total_num_PERSONS
    global video_name
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()

    # process the frames sequentially
    keypoints_list = []
    bbox_dets_list = []
    # frame_prev = -1
    # frame_cur = 0
    img_id = -1
    next_id = 0
    bbox_dets_list_list = []
    keypoints_list_list = []

    flag_mandatory_keyframe = False

    img_paths = get_immediate_childfile_paths(image_folder)
    num_imgs = len(img_paths)
    total_num_FRAMES = num_imgs

    # img_ids that have GT bboxes
    gt_bbox_img_id_list = [0]

    seed_mode = False

    while img_id < num_imgs - 1:
        img_id += 1
        img_path = img_paths[img_id]
        print("Current tracking: [image_id:{}]".format(img_id))
        frame_cur = img_id

        bbox_dets_list = []  # keyframe: start from empty
        keypoints_list = []  # keyframe: start from empty

        if img_id in gt_bbox_img_id_list:
            # current frame is a GT frame
            # once the data pipeline is ready this should use real GT; for now it is pseudo-GT
            ##  TODO remove once the data is ready
            human_candidates = inference_yolov3(img_path)  # get bboxes
            num_dets = len(human_candidates)
            # estimate keypoints for each bbox
            for det_id in range(num_dets):
                bbox_det = human_candidates[det_id]
                bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, enlarge_scale)
                bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)
                # update current frame bbox
                bbox_det_dict = {"img_id": img_id,
                                 "det_id": det_id,
                                 "imgpath": img_path,
                                 "track_id": det_id,
                                 "bbox": bbox_det}

                # keypoint inference, with timing
                st_time_pose = time.time()
                keypoints = inference_keypoints(pose_estimator, bbox_det_dict)[0]["keypoints"]
                end_time_pose = time.time()
                total_time_POSE += (end_time_pose - st_time_pose)

                keypoints_dict = {"img_id": img_id,
                                  "det_id": det_id,
                                  "imgpath": img_path,
                                  "track_id": det_id,
                                  "keypoints": keypoints}
                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
            # assert len(bbox_dets_list) == 2
            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)
        else:
            # 当前帧非gt帧
            # perform detection at keyframes
            if seed_mode:
                logger.info("img_id:{},seed_mode".format(img_id))
                # fetch the previous frame's info
                bbox_list_prev_frame = bbox_dets_list_list[img_id - 1].copy()
                keypoints_list_prev_frame = keypoints_list_list[img_id - 1].copy()
                num_prev_bbox = len(bbox_list_prev_frame)

                my_enlarge_scale = 0.3
                cur_image = cv2.imread(img_path)
                cur_image_name = os.path.basename(img_path).split('.')[0]
                cnt = 0
                for prev_det_id in range(num_prev_bbox):
                    prev_bbox_det = bbox_list_prev_frame[prev_det_id]["bbox"]  # xywh
                    track_id = bbox_list_prev_frame[prev_det_id]['track_id']
                    prev_enlarge_bbox_det = x1y1x2y2_to_xywh(
                        enlarge_bbox(xywh_to_x1y1x2y2(prev_bbox_det), my_enlarge_scale))
                    x1 = max(0, int(prev_enlarge_bbox_det[0]))
                    y1 = max(0, int(prev_enlarge_bbox_det[1]))
                    x2 = int(prev_enlarge_bbox_det[0] + prev_enlarge_bbox_det[2])
                    y2 = int(prev_enlarge_bbox_det[1] + prev_enlarge_bbox_det[3])
                    crop_image = cur_image[y1:y2, x1:x2].copy()
                    crop_image_folder_path = os.path.join(image_seed_crop_output_path, video_name, cur_image_name)
                    create_folder(crop_image_folder_path)
                    crop_image_path = os.path.join(crop_image_folder_path, "{:0>3d}".format(prev_det_id)) + '.jpg'
                    cv2.imwrite(crop_image_path, crop_image)
                    # inspect the cropped image
                    human_candidates, confidence_scores = inference_yolov3_v1(crop_image_path)
                    logger.info(confidence_scores)
                    if len(human_candidates) > 0 and confidence_scores[0] > 0.90:
                        selected_bbox = human_candidates[0]
                        x1y1x2y2 = xywh_to_x1y1x2y2(selected_bbox)
                        # top-left corner coordinates
                        top_left_point_x, top_left_point_y = min(x1y1x2y2[0], x1y1x2y2[2]), min(x1y1x2y2[1],
                                                                                                x1y1x2y2[3])
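                        # map the crop-local bbox back to full-image coordinates
                        # by adding the crop origin (x1, y1)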
                        best_bbox_det = [x1 + top_left_point_x, y1 + top_left_point_y, selected_bbox[2],
                                         selected_bbox[3]]

                        bbox_det_dict = {"img_id": img_id,
                                         "det_id": cnt,
                                         "imgpath": img_path,
                                         "track_id": track_id,
                                         "bbox": best_bbox_det}
                        crop_keypoints = inference_keypoints(pose_estimator, bbox_det_dict)[0]["keypoints"]
                        keypoints_dict = {"img_id": img_id,
                                          "det_id": cnt,
                                          "imgpath": img_path,
                                          "track_id": track_id,
                                          "keypoints": crop_keypoints}
                        bbox_dets_list.append(bbox_det_dict)
                        keypoints_list.append(keypoints_dict)
                        cnt += 1
                        # for proposal_det_id in range(num_proposal_dets):
                        #     proposal_bbox_det = human_candidates[proposal_det_id]
                        #     proposal_bbox_det_dict = {"img_id": 1,
                        #                               "imgpath": crop_image_path, "bbox": proposal_bbox_det}
                        #     crop_keypoints = inference_keypoints(pose_estimator, proposal_bbox_det_dict)[0][
                        #         "keypoints"]  # keypoint_numer *(x,y,score)
                        #     keypoint_sum_score = 0
                        #     for i in range(len(crop_keypoints)):
                        #         if i % 3 == 2:
                        #             keypoint_sum_score = keypoint_sum_score + crop_keypoints[i]
                        #     logger.info("{},{}".format(proposal_det_id, keypoint_sum_score))
                        #
                        #     crop_bbox_image_path = os.path.join(crop_image_folder_path,
                        #                                         "{:0>3d}-{:0>3d}".format(prev_det_id,
                        #                                                                  proposal_det_id)) + '.jpg'
                        #     cv2.imwrite(crop_bbox_image_path, cropped_bbox_image)
                assert cnt == len(bbox_dets_list)
                print("Final save bbox number: {} ".format(len(bbox_dets_list)))
                print("image path:{}".format(img_path))
                bbox_dets_list_list.append(bbox_dets_list)
                keypoints_list_list.append(keypoints_list)
                seed_mode = False
            else:
                st_time_detection = time.time()
                # human_candidates: (center_x, center_y, w, h)
                human_candidates, confidence_scores = inference_yolov3_v1(img_path)  # get bboxes

                end_time_detection = time.time()
                total_time_DET += (end_time_detection - st_time_detection)

                num_dets = len(human_candidates)
                print("Keyframe: {} detections".format(num_dets))

                # if nothing detected at this frame
                if num_dets <= 0:
                    ## TODO
                    break

                # estimate keypoints for each bbox
                for det_id in range(num_dets):
                    bbox_det = human_candidates[det_id]
                    bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                    bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, enlarge_scale)
                    bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)
                    # update current frame bbox
                    bbox_det_dict = {"img_id": img_id,
                                     "det_id": det_id,
                                     "imgpath": img_path,
                                     "track_id": None,
                                     "bbox": bbox_det}

                    # keypoint inference, with timing
                    st_time_pose = time.time()
                    keypoints = inference_keypoints(pose_estimator, bbox_det_dict)[0]["keypoints"]
                    end_time_pose = time.time()
                    total_time_POSE += (end_time_pose - st_time_pose)

                    keypoints_dict = {"img_id": img_id,
                                      "det_id": det_id,
                                      "imgpath": img_path,
                                      "track_id": None,
                                      "keypoints": keypoints}
                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)

                # fetch the previous frame's info
                bbox_list_prev_frame = bbox_dets_list_list[img_id - 1].copy()
                keypoints_list_prev_frame = keypoints_list_list[img_id - 1].copy()

                ############ cropping
                # if img_id in [34, 35, 36, 37, 38]:
                #     cnt = 0
                #     for bbox_info in bbox_list_prev_frame:
                #         bbox_det = bbox_info['bbox']
                #         image_path = bbox_info['imgpath']
                #         frame_name = os.path.basename(image_path)
                #         frame_name = frame_name.split('.')[0]
                #         video_name = os.path.basename(image_folder)
                #         image = cv2.imread(image_path)
                #         bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                #         bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, 0.1)
                #         bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)
                #         x1, y1, w, h = max(int(bbox_det[0]), 0), max(int(bbox_det[1]), 0), bbox_det[2], bbox_det[3]
                #         ### get the cropped image
                #         cropped_image = image[y1:(y1 + h), x1:(x1 + w)]
                #         create_folder(os.path.join(image_crop_output_path, video_name))
                #         cropped_image_path = os.path.join(image_crop_output_path, video_name,
                #                                           '{}-{:0>3d}.jpg'.format(frame_name, cnt))
                #         cv2.imwrite(cropped_image_path, cropped_image)
                #         ### find bboxes
                #         crop_human_candidates, _ = inference_yolov3_v1(cropped_image_path)
                #         for det_id in range(len(crop_human_candidates)):
                #             bbox_det = crop_human_candidates[det_id]
                #             ### draw the bbox
                #             # cropped_bbox_image = visualizer.draw_bbox_from_python_data(cropped_image, bbox_det)
                #             cropped_bbox_image = cv2.rectangle(cropped_image.copy(), (int(bbox_det[0]), int(bbox_det[1])),
                #                                                (int(bbox_det[0] + bbox_det[2]),
                #                                                 int(bbox_det[1] + bbox_det[3])),
                #                                                (255, 0, 255), thickness=3)
                #             cropped_image_bbox_path = os.path.join(image_crop_output_path, video_name,
                #                                                    '{}-{:0>3d}-{:0>3d}.jpg'.format(frame_name, cnt, det_id))
                #             cv2.imwrite(cropped_image_bbox_path, cropped_bbox_image)
                #         cnt += 1

                ##############

                num_bbox_prev_frame = len(bbox_list_prev_frame)

                # gather the three matching cues: detector confidence, pose similarity, and IoU
                confidence_scores = np.array(confidence_scores)
                confidence_scores = confidence_scores[:, np.newaxis]
                pose_matching_scores = np.zeros([num_dets, num_bbox_prev_frame], dtype=float)
                iou_scores = np.ones([num_dets, num_bbox_prev_frame], dtype=float)
                prev_track_ids = []
                for bbox_prev_index in range(num_bbox_prev_frame):
                    # track ids present in the previous frame
                    track_id = keypoints_list_prev_frame[bbox_prev_index]["track_id"]
                    prev_track_ids.append(track_id)
                for det_id in range(num_dets):
                    for bbox_prev_index in range(num_bbox_prev_frame):
                        keypoints_cur_frame = keypoints_list[det_id]["keypoints"]
                        bbox_cur_frame = bbox_dets_list[det_id]["bbox"]

                        keypoints_prev_frame = keypoints_list_prev_frame[bbox_prev_index]["keypoints"]
                        bbox_prev_frame = bbox_list_prev_frame[bbox_prev_index]["bbox"]
                        # get pose match score
                        pose_matching_scores[det_id, bbox_prev_index] = get_pose_matching_score(
                            keypoints_cur_frame,
                            keypoints_prev_frame,
                            bbox_cur_frame,
                            bbox_prev_frame)

                        # get bbox distance score
                        iou_scores[det_id, bbox_prev_index] = iou(bbox_cur_frame, bbox_prev_frame, xyxy=False)

                ###########################
                ## select the current frame's boxes by these criteria ##
                ###########################
                bbox_dets_list, keypoints_list = select_bbox_by_criterion(bbox_dets_list, keypoints_list,
                                                                          confidence_scores,
                                                                          pose_matching_scores, iou_scores,
                                                                          prev_track_ids)
                num_save_bbox = len(bbox_dets_list)

                # if the number of people changed, redo this frame in seed mode
                if num_save_bbox < num_bbox_prev_frame:
                    seed_mode = True
                    img_id -= 1
                    continue
                print("Final save bbox number: {} ".format(len(bbox_dets_list)))
                print("image path:{}".format(img_path))
                bbox_dets_list_list.append(bbox_dets_list)
                keypoints_list_list.append(keypoints_list)

    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list, bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path, classes, joint_pairs, joint_names,
                                               image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")

        fps = 5  # originally 25
        visualizer.make_video_from_images(img_paths, output_video_path, fps=fps, size=None, is_color=True,
                                          format="XVID")
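The listings above score box pairs with iou(bbox_cur_frame, bbox_prev_frame, xyxy=False), another helper not reproduced here. A minimal sketch consistent with those calls, assuming the xywh boxes store a top-left corner (the v1 detector is commented above as returning center-based boxes, in which case a center-to-corner conversion would come first):

def iou_xywh_sketch(box_a, box_b):
    """IoU of two [x, y, w, h] boxes; (x, y) is assumed to be the top-left corner."""
    ax2, ay2 = box_a[0] + box_a[2], box_a[1] + box_a[3]
    bx2, by2 = box_b[0] + box_b[2], box_b[1] + box_b[3]
    # intersection rectangle
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union > 0 else 0.0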
Example n. 6
def light_track(pose_estimator, image_folder, output_json_path,
                visualize_folder, output_video_path, gt_info):
    global total_time_POSE, total_time_DET, total_time_ALL, total_num_FRAMES, total_num_PERSONS
    global video_name
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()

    next_id = 1
    bbox_dets_list_list = []
    keypoints_list_list = []

    num_imgs = len(gt_info)
    total_num_FRAMES = num_imgs
    first_img_id = 0

    start_from_labeled = False
    if start_from_labeled:
        first_img_id = find_first_labeled_opensvai_json(gt_info)

    img_id = first_img_id
    """  之后的数据关联如何做?

    TODO  
    相邻两帧之间的检测框匹配,权值使用 total score = IOU score + Pose similarity, 之后再使用匈牙利算法进行二分图的匹配。

    大致思路:
        1.先算出距离相似度和pose相似度,将两者相加作为 框之间的联系。
        2.过滤低置信度的框。
        3.二分图匹配。
    """
    iou_alpha1 = 1
    pose_alpha1 = 1
    while img_id < num_imgs:

        img_gt_info = gt_info[img_id]
        image_name, labeled, candidates_info = read_image_data_opensvai_json(
            img_gt_info)
        img_path = os.path.join(image_folder, image_name)

        bbox_dets_list = []  # keyframe: start from empty
        keypoints_list = []  # keyframe: start from empty

        if labeled and (img_id - first_img_id) % 5 == 0:
            logger.info("{},{}".format('gt', img_id))
            # gt frame
            num_dets = len(candidates_info)

            if img_id == first_img_id:
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(
                        candidates_info, det_id)
                    # use the first frame's GT directly
                    bbox_det_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": track_id,
                        "bbox": bbox_det
                    }
                    keypoints_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": track_id,
                        "keypoints": keypoints
                    }
                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)
                    next_id = max(next_id, track_id + 1)
            else:  # Not First Frame
                bbox_list_prev_frame = bbox_dets_list_list[img_id -
                                                           first_img_id -
                                                           1].copy()
                keypoints_list_prev_frame = keypoints_list_list[img_id -
                                                                first_img_id -
                                                                1].copy()
                scores = np.zeros((num_dets, len(keypoints_list_prev_frame)))
                for det_id in range(num_dets):
                    track_id, bbox_det, keypoints = get_candidate_info_opensvai_json(
                        candidates_info, det_id)
                    bbox_det_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": None,
                        "bbox": bbox_det
                    }
                    keypoints_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "imgpath": img_path,
                        "track_id": None,
                        "keypoints": keypoints
                    }
                    # score the current bbox against each previous-frame bbox
                    for prev_det_id in range(len(keypoints_list_prev_frame)):
                        prev_bbox_det_dict = bbox_list_prev_frame[prev_det_id]
                        prev_keypoints_dict = keypoints_list_prev_frame[prev_det_id]
                        iou_score = iou(bbox_det,
                                        prev_bbox_det_dict['bbox'],
                                        xyxy=False)
                        if iou_score > 0.5:
                            pose_match_score = get_pose_matching_score(
                                keypoints, prev_keypoints_dict["keypoints"],
                                bbox_det, prev_bbox_det_dict['bbox'])
                            scores[det_id, prev_det_id] = (
                                iou_alpha1 * iou_score +
                                pose_alpha1 * pose_match_score)

                    bbox_dets_list.append(bbox_det_dict)
                    keypoints_list.append(keypoints_dict)

                bbox_dets_list = bipartite_graph_matching(
                    bbox_dets_list, bbox_list_prev_frame, scores)
                # # first match by spatial similarity
                # track_id, match_index = get_track_id_SpatialConsistency(bbox_det, bbox_list_prev_frame)
                #
                # bbox_det_dict = {"img_id": img_id,
                #                  "det_id": det_id,
                #                  "imgpath": img_path,
                #                  "track_id": None,
                #                  "bbox": bbox_det}
                # keypoints_dict = {"img_id": img_id,
                #                   "det_id": det_id,
                #                   "imgpath": img_path,
                #                   "track_id": None,
                #                   "keypoints": keypoints}
                #
                # if track_id != -1:  # if candidate from prev frame matched, prevent it from matching another
                #     del bbox_list_prev_frame[match_index]
                #     del keypoints_list_prev_frame[match_index]
                #     bbox_det_dict["track_id"] = track_id
                #     keypoints_dict["track_id"] = track_id
                #     bbox_dets_list.append(bbox_det_dict)
                #     keypoints_list.append(keypoints_dict)
                #     # matched; move on to the next one
                #     continue
                #
                # # then use pose similarity
                # track_id, match_index = get_track_id_SGCN(bbox_det, bbox_list_prev_frame, keypoints,
                #                                           keypoints_list_prev_frame)
                #
                # if track_id != -1:
                #     del bbox_list_prev_frame[match_index]
                #     del keypoints_list_prev_frame[match_index]
                #     bbox_det_dict["track_id"] = track_id
                #     keypoints_dict["track_id"] = track_id
                #     bbox_dets_list.append(bbox_det_dict)
                #     keypoints_list.append(keypoints_dict)
                #     # matched; move on to the next one
                #     continue
                #
                # # no match found in the previous frame, so assign a new id
                # track_id = next_id
                # next_id += 1
                #
                # bbox_det_dict["track_id"] = track_id
                # keypoints_dict["track_id"] = track_id
                #
                # bbox_dets_list.append(bbox_det_dict)
                # keypoints_list.append(keypoints_dict)
            # the frame's final association result is empty.
            if len(bbox_dets_list) == 0:
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "bbox": [0, 0, 2, 2]
                }
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "keypoints": []
                }
                keypoints_list.append(keypoints_dict)

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

        else:
            logger.info("{},{}".format('normal', img_id))
            ''' NOT GT Frame '''
            candidates_total = []
            candidates_from_detector = inference_yolov3(img_path)

            candidates_from_prev = []
            ''' supplement candidate boxes using the previous frame's info '''
            if img_id > first_img_id:
                bbox_list_prev_frame = bbox_dets_list_list[img_id -
                                                           first_img_id -
                                                           1].copy()
                keypoints_list_prev_frame = keypoints_list_list[img_id -
                                                                first_img_id -
                                                                1].copy()
                num_prev_bbox = len(bbox_list_prev_frame)
                for prev_det_id in range(num_prev_bbox):
                    # obtain bbox position and track id
                    keypoints = keypoints_list_prev_frame[prev_det_id][
                        'keypoints']
                    bbox_det_next = get_bbox_from_keypoints(keypoints)
                    if bbox_invalid(bbox_det_next):
                        continue
                    # xywh
                    candidates_from_prev.append(bbox_det_next)
                # my_enlarge_scale = 0.2
                # cur_image = cv2.imread(img_path)
                # cur_image_name = os.path.basename(img_path).split('.')[0]
                # for prev_det_id in range(num_prev_bbox):
                #     prev_bbox_det = bbox_list_prev_frame[prev_det_id]["bbox"]  # xywh
                #     prev_enlarge_bbox_det = x1y1x2y2_to_xywh(
                #         enlarge_bbox(xywh_to_x1y1x2y2(prev_bbox_det), my_enlarge_scale))
                #     x1, x2, y1, y2 = max(0, int(prev_enlarge_bbox_det[0])), int(
                #         prev_enlarge_bbox_det[0] + prev_enlarge_bbox_det[2]), \
                #                      max(0, int(prev_enlarge_bbox_det[1])), int(
                #         prev_enlarge_bbox_det[1] + prev_enlarge_bbox_det[3])
                #     crop_image = cur_image[y1:y2, x1:x2].copy()
                #     crop_image_folder_path = os.path.join(image_seed_crop_output_path, video_name, cur_image_name)
                #     create_folder(crop_image_folder_path)
                #     crop_image_path = os.path.join(crop_image_folder_path, "{:0>3d}".format(prev_det_id)) + '.jpg'
                #     cv2.imwrite(crop_image_path, crop_image)
                #     # inspect the cropped image
                #     human_candidates, confidence_scores = inference_yolov3_v1(crop_image_path)
                #     # logger.info(confidence_scores)
                #     if len(human_candidates) > 0 and confidence_scores[0] > 0.7:
                #         selected_bbox = human_candidates[0]
                #         x1y1x2y2 = xywh_to_x1y1x2y2(selected_bbox)
                #         # top-left corner coordinates
                #         top_left_point_x, top_left_point_y = min(x1y1x2y2[0], x1y1x2y2[2]), min(x1y1x2y2[1],
                #                                                                                 x1y1x2y2[3])
                #         bbox_det_in_original_pic = [x1 + top_left_point_x, y1 + top_left_point_y, selected_bbox[2],
                #                                     selected_bbox[3]]
                #         candidates_from_prev.append(bbox_det_in_original_pic)  # xywh
            ''' gather all candidate boxes for this frame '''
            candidates_total = candidates_from_detector + candidates_from_prev
            num_candidate = len(candidates_total)
            ''' use the keypoint confidences as the bbox confidence '''
            candidates_dets = []
            for candidate_id in range(num_candidate):
                bbox_det = candidates_total[candidate_id]
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": candidate_id,
                    "imgpath": img_path,
                    "track_id": None,
                    "bbox": bbox_det
                }
                keypoints = inference_keypoints(pose_estimator,
                                                bbox_det_dict)[0]['keypoints']

                bbox_det_next = xywh_to_x1y1x2y2(bbox_det)
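                # keypoints are flattened (x, y, score) triples, so [2::3] picks the
                # per-joint scores; the divisor 25 presumably matches the pose
                # model's joint count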
                score = sum(keypoints[2::3]) / 25
                if bbox_invalid(bbox_det_next) or score < 0.5:
                    continue
                candidate_det = bbox_det_next + [score]
                candidates_dets.append(candidate_det)
                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": candidate_id,
                    "imgpath": img_path,
                    "track_id": None,
                    "keypoints": keypoints
                }

                bbox_dets_list.append(bbox_det_dict)
                keypoints_list.append(keypoints_dict)
            ''' run NMS based on the bbox confidences '''
            if len(candidates_dets) == 0:
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "bbox": [0, 0, 2, 2]
                }
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "keypoints": []
                }
                keypoints_list.append(keypoints_dict)

                bbox_dets_list_list.append(bbox_dets_list)
                keypoints_list_list.append(keypoints_list)
                img_id += 1
                continue
            keep = py_cpu_nms(np.array(candidates_dets, dtype=np.float32), 0.5)
            candidates_total = np.array(candidates_total)[keep]
            bbox_dets_list = [bbox_dets_list[i] for i in keep]
            keypoints_list = [keypoints_list[i] for i in keep]
            """ Data association """
            # estimate keypoints for each bbox
            num_candidate = len(candidates_total)
            for candidate_id in range(num_candidate):
                bbox_det_dict = bbox_dets_list[candidate_id]
                keypoints_dict = keypoints_list[candidate_id]
                bbox_det = bbox_det_dict['bbox']
                keypoints = keypoints_dict['keypoints']
                # Data association
                # first match by spatial similarity
                if img_id > first_img_id:
                    track_id, match_index = get_track_id_SpatialConsistency(
                        bbox_det, bbox_list_prev_frame)
                    if track_id != -1:  # if candidate from prev frame matched, prevent it from matching another
                        del bbox_list_prev_frame[match_index]
                        del keypoints_list_prev_frame[match_index]
                        bbox_det_dict["track_id"] = track_id
                        bbox_det_dict["det_id"] = candidate_id
                        keypoints_dict["track_id"] = track_id
                        keypoints_dict["det_id"] = candidate_id
                        keypoints = inference_keypoints(
                            pose_estimator, bbox_det_dict)[0]['keypoints']
                        keypoints_dict['keypoints'] = keypoints
                        # matched; move on to the next one
                        continue

                    # then use pose similarity
                    track_id, match_index = get_track_id_SGCN(
                        bbox_det, bbox_list_prev_frame, keypoints,
                        keypoints_list_prev_frame)

                    if track_id != -1:
                        del bbox_list_prev_frame[match_index]
                        del keypoints_list_prev_frame[match_index]
                        bbox_det_dict["track_id"] = track_id
                        bbox_det_dict["det_id"] = candidate_id
                        keypoints_dict["track_id"] = track_id
                        keypoints_dict["det_id"] = candidate_id
                        keypoints = inference_keypoints(
                            pose_estimator, bbox_det_dict)[0]['keypoints']
                        keypoints_dict['keypoints'] = keypoints
                        # matched; move on to the next one
                        continue

                # no match found in the previous frame, so assign a new id
                track_id = next_id
                next_id += 1

                bbox_det_dict["track_id"] = track_id
                bbox_det_dict["det_id"] = candidate_id
                keypoints_dict["track_id"] = track_id
                keypoints_dict["det_id"] = candidate_id

                keypoints = inference_keypoints(pose_estimator,
                                                bbox_det_dict)[0]['keypoints']
                keypoints_dict['keypoints'] = keypoints

            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)

        img_id += 1
    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list,
                                          bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path,
                                               classes,
                                               joint_pairs,
                                               joint_names,
                                               image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")

        fps = 5  # originally 25
        visualizer.make_video_from_images(img_paths,
                                          output_video_path,
                                          fps=fps,
                                          size=None,
                                          is_color=True,
                                          format="XVID")
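This listing builds candidates_dets rows as [x1, y1, x2, y2, score] and prunes them with py_cpu_nms(np.array(...), 0.5), whose body is not included on this page. A minimal sketch of that style of greedy NMS, assuming only NumPy:

import numpy as np


def py_cpu_nms_sketch(dets, thresh):
    """Greedy NMS over dets rows of [x1, y1, x2, y2, score]; returns kept indices."""
    x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # process highest-scoring boxes first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # IoU of the kept box against every remaining box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= thresh]  # drop heavy overlaps, keep the rest
    return keep

The returned indices are then used to filter bbox_dets_list and keypoints_list in lockstep, as the listing does with keep.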
Example n. 7
def light_track(pose_estimator, image_folder, output_json_path,
                visualize_folder, output_video_path):
    global total_time_POSE, total_time_DET, total_time_ALL, total_num_FRAMES, total_num_PERSONS
    ''' 1. statistics: get total time for lighttrack processing'''
    st_time_total = time.time()

    # process the frames sequentially
    keypoints_list = []
    bbox_dets_list = []
    frame_prev = -1
    frame_cur = 0
    img_id = -1
    next_id = 0
    bbox_dets_list_list = []
    keypoints_list_list = []

    flag_mandatory_keyframe = False

    img_paths = get_immediate_childfile_paths(image_folder)
    num_imgs = len(img_paths)
    total_num_FRAMES = num_imgs

    while img_id < num_imgs - 1:
        img_id += 1
        img_path = img_paths[img_id]
        print("Current tracking: [image_id:{}]".format(img_id))

        frame_cur = img_id
        if (frame_cur == frame_prev):
            frame_prev -= 1
        ''' KEYFRAME: loading results from other modules '''
        if is_keyframe(img_id, keyframe_interval) or flag_mandatory_keyframe:
            flag_mandatory_keyframe = False
            bbox_dets_list = []  # keyframe: start from empty
            keypoints_list = []  # keyframe: start from empty

            # perform detection at keyframes
            st_time_detection = time.time()
            human_candidates = inference_yolov3(img_path)
            end_time_detection = time.time()
            total_time_DET += (end_time_detection - st_time_detection)

            num_dets = len(human_candidates)
            print("Keyframe: {} detections".format(num_dets))

            # if nothing detected at keyframe, regard next frame as keyframe because there is nothing to track
            if num_dets <= 0:
                # add empty result
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "bbox": [0, 0, 2, 2]
                }
                bbox_dets_list.append(bbox_det_dict)

                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": 0,
                    "track_id": None,
                    "imgpath": img_path,
                    "keypoints": []
                }
                keypoints_list.append(keypoints_dict)

                bbox_dets_list_list.append(bbox_dets_list)
                keypoints_list_list.append(keypoints_list)

                flag_mandatory_keyframe = True
                continue
            ''' 2. statistics: get total number of detected persons '''
            total_num_PERSONS += num_dets

            if img_id > 0:  # First frame does not have previous frame
                # bbox_list_prev_frame is a list of dicts; each dict has img_id (int), det_id (int), track_id (int), imgpath (str), and bbox (a list of 4 numbers).
                bbox_list_prev_frame = bbox_dets_list_list[img_id - 1].copy()
                keypoints_list_prev_frame = keypoints_list_list[img_id -
                                                                1].copy()

            # For each candidate, perform pose estimation and data association based on Spatial Consistency (SC)
            for det_id in range(num_dets):
                # obtain bbox position and track id
                bbox_det = human_candidates[det_id]

                # enlarge bbox by 20% with same center position
                bbox_x1y1x2y2 = xywh_to_x1y1x2y2(bbox_det)
                bbox_in_xywh = enlarge_bbox(bbox_x1y1x2y2, enlarge_scale)
                bbox_det = x1y1x2y2_to_xywh(bbox_in_xywh)

                # Keyframe: use provided bbox
                if bbox_invalid(bbox_det):
                    track_id = None  # this id means null
                    keypoints = []
                    bbox_det = [0, 0, 2, 2]
                    # update current frame bbox
                    bbox_det_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "track_id": track_id,
                        "imgpath": img_path,
                        "bbox": bbox_det
                    }
                    bbox_dets_list.append(bbox_det_dict)
                    # update current frame keypoints
                    keypoints_dict = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "track_id": track_id,
                        "imgpath": img_path,
                        "keypoints": keypoints
                    }
                    keypoints_list.append(keypoints_dict)
                    continue

                # update current frame bbox
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": det_id,
                    "imgpath": img_path,
                    "bbox": bbox_det
                }

                # obtain keypoints for each bbox position in the keyframe
                st_time_pose = time.time()
                keypoints = inference_keypoints(pose_estimator,
                                                bbox_det_dict)[0]["keypoints"]
                end_time_pose = time.time()
                total_time_POSE += (end_time_pose - st_time_pose)

                if img_id == 0:  # First frame, all ids are assigned automatically
                    track_id = next_id
                    next_id += 1
                else:
                    track_id, match_index = get_track_id_SpatialConsistency(
                        bbox_det, bbox_list_prev_frame)

                    if track_id != -1:  # if candidate from prev frame matched, prevent it from matching another
                        del bbox_list_prev_frame[match_index]
                        del keypoints_list_prev_frame[match_index]

                # update current frame bbox
                bbox_det_dict = {
                    "img_id": img_id,
                    "det_id": det_id,
                    "track_id": track_id,
                    "imgpath": img_path,
                    "bbox": bbox_det
                }
                bbox_dets_list.append(bbox_det_dict)

                # update current frame keypoints
                keypoints_dict = {
                    "img_id": img_id,
                    "det_id": det_id,
                    "track_id": track_id,
                    "imgpath": img_path,
                    "keypoints": keypoints
                }
                keypoints_list.append(keypoints_dict)

            # For each candidate that is not associated yet, perform data association based on Pose Similarity (SGCN)
            for det_id in range(num_dets):
                bbox_det_dict = bbox_dets_list[det_id]
                keypoints_dict = keypoints_list[det_id]
                assert (det_id == bbox_det_dict["det_id"])
                assert (det_id == keypoints_dict["det_id"])

                if bbox_det_dict[
                        "track_id"] == -1:  # this id means matching not found yet
                    track_id, match_index = get_track_id_SGCN(
                        bbox_det_dict["bbox"], bbox_list_prev_frame,
                        keypoints_dict["keypoints"], keypoints_list_prev_frame)

                    if track_id != -1:  # if candidate from prev frame matched, prevent it from matching another
                        del bbox_list_prev_frame[match_index]
                        del keypoints_list_prev_frame[match_index]
                        bbox_det_dict["track_id"] = track_id
                        keypoints_dict["track_id"] = track_id

                    # if still can not find a match from previous frame, then assign a new id
                    if track_id == -1 and not bbox_invalid(
                            bbox_det_dict["bbox"]):
                        bbox_det_dict["track_id"] = next_id
                        keypoints_dict["track_id"] = next_id
                        next_id += 1

            # update frame
            bbox_dets_list_list.append(bbox_dets_list)
            keypoints_list_list.append(keypoints_list)
            frame_prev = frame_cur
            print("This is KeyFrame")
        else:
            ''' NOT KEYFRAME: multi-target pose tracking '''
            bbox_dets_list_next = []
            keypoints_list_next = []
            print("This is NOT KeyFrame")
            num_dets = len(keypoints_list)
            total_num_PERSONS += num_dets

            if num_dets == 0:
                flag_mandatory_keyframe = True

            for det_id in range(num_dets):
                keypoints = keypoints_list[det_id]["keypoints"]

                # for non-keyframes, the tracked target preserves its track_id
                track_id = keypoints_list[det_id]["track_id"]

                # next frame bbox
                bbox_det_next = get_bbox_from_keypoints(keypoints)
                if bbox_det_next[2] == 0 or bbox_det_next[3] == 0:
                    bbox_det_next = [0, 0, 2, 2]
                    total_num_PERSONS -= 1
                assert (bbox_det_next[2] != 0 and bbox_det_next[3] != 0
                        )  # width and height must not be zero
                bbox_det_dict_next = {
                    "img_id": img_id,
                    "det_id": det_id,
                    "track_id": track_id,
                    "imgpath": img_path,
                    "bbox": bbox_det_next
                }

                # next frame keypoints
                st_time_pose = time.time()
                keypoints_next = inference_keypoints(
                    pose_estimator, bbox_det_dict_next)[0]["keypoints"]
                end_time_pose = time.time()
                total_time_POSE += (end_time_pose - st_time_pose)
                # print("time for pose estimation: ", (end_time_pose - st_time_pose))

                # check whether the target is lost
                target_lost = is_target_lost(keypoints_next)

                if target_lost is False:
                    bbox_dets_list_next.append(bbox_det_dict_next)
                    keypoints_dict_next = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "track_id": track_id,
                        "imgpath": img_path,
                        "keypoints": keypoints_next
                    }
                    keypoints_list_next.append(keypoints_dict_next)

                else:
                    # remove this bbox, do not register its keypoints
                    bbox_det_dict_next = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "track_id": None,
                        "imgpath": img_path,
                        "bbox": [0, 0, 2, 2]
                    }
                    bbox_dets_list_next.append(bbox_det_dict_next)

                    keypoints_null = 45 * [0]
                    keypoints_dict_next = {
                        "img_id": img_id,
                        "det_id": det_id,
                        "track_id": None,
                        "imgpath": img_path,
                        "keypoints": []
                    }
                    keypoints_list_next.append(keypoints_dict_next)
                    print(
                        "Target lost. Process this frame again as keyframe. \n\n\n"
                    )
                    flag_mandatory_keyframe = True

                    total_num_PERSONS -= 1
                    ## Re-process this frame by treating it as a keyframe
                    if img_id not in [0]:
                        img_id -= 1
                    break

            # update frame
            if flag_mandatory_keyframe is False:
                bbox_dets_list = bbox_dets_list_next
                keypoints_list = keypoints_list_next
                bbox_dets_list_list.append(bbox_dets_list)
                keypoints_list_list.append(keypoints_list)
                frame_prev = frame_cur
    ''' 1. statistics: get total time for lighttrack processing'''
    end_time_total = time.time()
    total_time_ALL += (end_time_total - st_time_total)

    # convert results into openSVAI format
    print("Exporting Results in openSVAI Standard Json Format...")
    poses_standard = pose_to_standard_mot(keypoints_list_list,
                                          bbox_dets_list_list)
    # json_str = python_to_json(poses_standard)
    # print(json_str)

    # output json file
    pose_json_folder, _ = get_parent_folder_from_path(output_json_path)
    create_folder(pose_json_folder)
    write_json_to_file(poses_standard, output_json_path)
    print("Json Export Finished!")

    # visualization
    if flag_visualize is True:
        print("Visualizing Pose Tracking Results...")
        create_folder(visualize_folder)
        visualizer.show_all_from_standard_json(output_json_path,
                                               classes,
                                               joint_pairs,
                                               joint_names,
                                               image_folder,
                                               visualize_folder,
                                               flag_track=True)
        print("Visualization Finished!")

        img_paths = get_immediate_childfile_paths(visualize_folder)
        avg_fps = total_num_FRAMES / total_time_ALL
        # make_video_from_images(img_paths, output_video_path, fps=avg_fps, size=None, is_color=True, format="XVID")
        visualizer.make_video_from_images(img_paths,
                                          output_video_path,
                                          fps=25,
                                          size=None,
                                          is_color=True,
                                          format="XVID")
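In the non-keyframe branch above, each target is propagated by wrapping a box around its previous keypoints via get_bbox_from_keypoints, which is not reproduced on this page. A minimal sketch, assuming keypoints come as flattened (x, y, score) triples; the enlarge factor and score threshold are illustrative choices, not the repo's constants:

def bbox_from_keypoints_sketch(keypoints, enlarge_scale=0.2, score_thresh=0.05):
    """Wrap an xywh box around the confident joints of a flattened keypoint list."""
    xs = [keypoints[i] for i in range(0, len(keypoints), 3)
          if keypoints[i + 2] > score_thresh]
    ys = [keypoints[i + 1] for i in range(0, len(keypoints), 3)
          if keypoints[i + 2] > score_thresh]
    if not xs:
        return [0, 0, 2, 2]  # the listings use this as the "empty" placeholder box
    x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)
    w, h = x2 - x1, y2 - y1
    # enlarge around the same center, as the keyframe branch does for detections
    x1 -= w * enlarge_scale / 2
    y1 -= h * enlarge_scale / 2
    w *= 1 + enlarge_scale
    h *= 1 + enlarge_scale
    return [x1, y1, w, h]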