def process_heatmap(heatmap, image, scale, class_names, skeleton_lines):
    """Decode keypoints from a heatmap, print them, and display them drawn on the image.

    Args:
        heatmap: heatmap passed to ``post_process_heatmap``.
        image: source image; converted with ``np.array(..., dtype='uint8')``
            before rendering.
        scale: (x_scale, y_scale) pair multiplied into each keypoint's x / y.
        class_names: keypoint names, indexed in the same order as the
            keypoints returned by ``post_process_heatmap``.
        skeleton_lines: skeleton definition forwarded to ``render_skeleton``.

    Returns:
        None. The keypoints are printed, and the rendered image is shown with
        ``Image.fromarray(...).show()``.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump with
    # wall-clock adjustments and is too coarse for a sub-ms measurement.
    start = time.perf_counter()
    # parse out predicted keypoint from heatmap
    keypoints = post_process_heatmap(heatmap)

    # Rescale keypoints back to the original image size; the third element is
    # carried through unchanged. The extra factor of 4 presumably converts
    # heatmap coordinates to model-input coordinates (heatmap assumed to be
    # 1/4 of the input resolution) -- TODO confirm.
    keypoints_dict = dict()
    for i, keypoint in enumerate(keypoints):
        keypoints_dict[class_names[i]] = (keypoint[0] * scale[0] * 4,
                                          keypoint[1] * scale[1] * 4,
                                          keypoint[2])

    end = time.perf_counter()
    print("PostProcess time: {:.8f}ms".format((end - start) * 1000))

    print('Keypoints detection result:')
    for keypoint in keypoints_dict.items():
        print(keypoint)

    # draw the keypoint skeleton on image
    image_array = np.array(image, dtype='uint8')
    image_array = render_skeleton(image_array, keypoints_dict, skeleton_lines)

    Image.fromarray(image_array).show()
예제 #2
0
    def predict(self, image_data):
        """Run the hourglass model and decode keypoints from its final heatmap.

        Args:
            image_data: input passed directly to ``self.hourglass_model.predict``.
                Only the first sample of the resulting batch is decoded.

        Returns:
            The result of ``post_process_heatmap`` for that sample's heatmap.
        """
        stack_outputs = self.hourglass_model.predict(image_data)
        # The last model output is used as the final prediction heatmap,
        # and [0] selects the first sample in the batch.
        final_heatmap = stack_outputs[-1][0]
        return post_process_heatmap(final_heatmap)
def get_predicted_kp_from_htmap(heatmap, meta, outres):
    """Decode keypoints from a heatmap and map them back to original-image coordinates.

    Args:
        heatmap: heatmap passed to ``post_process_heatmap``.
        meta: mapping that provides ``'center'`` and ``'scale'`` for ``transform``.
        outres: resolution forwarded to ``transform`` as ``res``.

    Returns:
        A numpy array of the decoded keypoints. Columns 0 and 1 of each row are
        replaced by the inverse-transformed location; any other columns are
        kept as decoded.
    """
    # Peak extraction (described as NMS in the original) yields keypoint locations.
    keypoints = np.array(post_process_heatmap(heatmap))

    # Write into a copy so that the decoded rows passed to transform stay untouched.
    mapped = keypoints.copy()
    for row, keypoint in enumerate(keypoints):
        # invert=1: use the meta information to transform back to the original image.
        mapped[row, 0:2] = transform(keypoint, meta['center'], meta['scale'],
                                     res=outres, invert=1, rot=0)
    return mapped
def _hourglass_predict_by_format(model, model_format, session, image_data):
    """Run one inference with the backend matching ``model_format`` and return its heatmap.

    ``session`` is only used (and only required) for the 'MNN' backend.

    Raises:
        ValueError: if ``model_format`` is not one of the supported formats.
    """
    # support of tflite model
    if model_format == 'TFLITE':
        return hourglass_predict_tflite(model, image_data)
    # support of MNN model
    if model_format == 'MNN':
        return hourglass_predict_mnn(model, session, image_data)
    # support of TF 1.x frozen pb model
    if model_format == 'PB':
        return hourglass_predict_pb(model, image_data)
    # support of ONNX model
    if model_format == 'ONNX':
        return hourglass_predict_onnx(model, image_data)
    # normal keras h5 model
    if model_format == 'H5':
        return hourglass_predict_keras(model, image_data)
    raise ValueError('invalid model format')


def eval_PCK(model, model_format, eval_dataset, class_names, score_threshold, normalize, conf_threshold, save_result=False, skeleton_lines=None):
    """Evaluate the PCK accuracy of a keypoint model over ``eval_dataset``.

    Runs inference one sample at a time and decodes keypoints with
    ``post_process_heatmap(heatmap, conf_threshold)``. Each prediction is scored
    against the ground-truth ``metainfo['tpts']`` with ``keypoint_accuracy``.
    All predictions, reverted to the original image size, are written to
    ``result/keypoints_result.json`` as a list of COCO keypoint result dicts.

    Args:
        model: model object for the selected backend.
        model_format: one of 'TFLITE', 'MNN', 'PB', 'ONNX', 'H5'.
        eval_dataset: provides ``get_dataset_size()`` and ``generator(...)``.
        class_names: keypoint class names. Entry i pairs with entry i of the
            list returned by ``keypoint_accuracy``.
        score_threshold: PCK threshold forwarded to ``keypoint_accuracy``.
        normalize: normalization value forwarded to ``keypoint_accuracy``.
        conf_threshold: confidence threshold forwarded to ``post_process_heatmap``.
        save_result: if True, also save rendered detections and draw a per-class
            PCK bar plot to ``result/PCK.jpg``.
        skeleton_lines: skeleton definition used when saving rendered detections.

    Returns:
        ``(total_accuracy, accuracy_dict)``. ``total_accuracy`` is the overall
        fraction of succeeded keypoints, and ``accuracy_dict`` maps each class
        name to its own fraction. A class, or a whole run, with no scored
        keypoints reports 0.0 instead of raising a division error.

    Raises:
        ValueError: if ``model_format`` is unsupported. This is raised when the
            first sample is processed.
    """
    # The MNN inference engine needs an explicit session; other backends don't.
    session = model.createSession() if model_format == 'MNN' else None

    succeed_dict = {class_name: 0 for class_name in class_names}
    fail_dict = {class_name: 0 for class_name in class_names}
    accuracy_dict = {class_name: 0. for class_name in class_names}

    # init output list for coco result json generation
    # coco keypoints result is a list of following format dict:
    # {
    #  "image_id": int,
    #  "category_id": int,
    #  "keypoints": [x1,y1,v1,...,xk,yk,vk],
    #  "score": float
    # }
    #
    output_list = []

    count = 0
    batch_size = 1
    dataset_size = eval_dataset.get_dataset_size()
    # The context manager closes the progress bar even if inference raises.
    with tqdm(total=dataset_size, desc='Eval model') as pbar:
        for image_data, _, batch_meta in eval_dataset.generator(batch_size, 8, sigma=1, is_shuffle=False, with_meta=True):
            # The generator crops out a single person area, resizes it to
            # input_size and normalizes the image. It is stopped manually after
            # dataset_size samples (presumably it does not end by itself -- TODO confirm).
            count += batch_size
            if count > dataset_size:
                break

            heatmap = _hourglass_predict_by_format(model, model_format, session, image_data)
            heatmap_size = heatmap.shape[0:2]

            # get predict keypoints from heatmap
            pred_keypoints = np.array(post_process_heatmap(heatmap, conf_threshold))

            # get ground truth keypoints (transformed)
            metainfo = batch_meta[0]
            gt_keypoints = metainfo['tpts']

            # calculate succeed & failed keypoints for prediction
            result_list = keypoint_accuracy(pred_keypoints, gt_keypoints, score_threshold, normalize)
            for i, class_name in enumerate(class_names):
                # Only 0 (fail) and 1 (succeed) are counted. Any other value is
                # ignored (presumably keypoints excluded from scoring -- confirm).
                if result_list[i] == 0:
                    fail_dict[class_name] += 1
                elif result_list[i] == 1:
                    succeed_dict[class_name] += 1

            # revert predict keypoints back to origin image size
            reverted_pred_keypoints = revert_keypoints(pred_keypoints, metainfo, heatmap_size)

            # collect the coco result dict for this image
            output_list.append(get_result_dict(reverted_pred_keypoints, metainfo))

            if save_result:
                # render keypoints skeleton on image and save result
                save_keypoints_detection(reverted_pred_keypoints, metainfo, class_names, skeleton_lines)
            pbar.update(batch_size)

    # save to coco result json; `with` closes the file even if writing fails
    touchdir('result')
    with open(os.path.join('result', 'keypoints_result.json'), 'w') as json_fp:
        json_fp.write(json.dumps(output_list))

    # Per-class accuracy. Classes that were never scored keep their 0.0 default
    # instead of raising ZeroDivisionError.
    for class_name in class_names:
        evaluated = succeed_dict[class_name] + fail_dict[class_name]
        if evaluated:
            accuracy_dict[class_name] = succeed_dict[class_name] * 1.0 / evaluated

    # get PCK accuracy from succeed & failed keypoints
    total_succeed = np.sum(list(succeed_dict.values()))
    total_fail = np.sum(list(fail_dict.values()))
    total_evaluated = total_succeed + total_fail
    # Guard the empty case, which would otherwise yield nan plus a RuntimeWarning.
    total_accuracy = total_succeed * 1.0 / total_evaluated if total_evaluated else 0.0

    if save_result:
        # Draw PCK plot
        window_title = "PCK evaluation"
        # total_accuracy is a fraction, so scale it to a percentage to match the '%' suffix.
        plot_title = "PCK@{0} score = {1:.2f}%".format(score_threshold, total_accuracy * 100)
        x_label = "Accuracy"
        output_path = os.path.join('result', 'PCK.jpg')
        draw_plot_func(accuracy_dict, len(accuracy_dict), window_title, plot_title, x_label, output_path, to_show=False, plot_color='royalblue', true_p_bar='')

    return total_accuracy, accuracy_dict