Esempio n. 1
0
def test_keypoint_detection():
    """Verify that drawing a single joint changes the image pixels.

    A blank 100x200 RGB image is passed to ``visualize_keypoint_detection``
    with only joint 0 populated; the returned image must differ from the input.
    """
    input_image = PIL.Image.new("RGB", size=(100, 200))
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # ``int`` is the exact replacement for the old alias.
    joints = np.zeros(shape=(17, 3), dtype=int)
    # Only joint 0 gets non-zero values; all other rows stay zero.
    joints[0] = [30, 30, 1]

    result_image = visualize_keypoint_detection(np.array(input_image), joints)

    assert not np.all(np.array(input_image) == np.array(result_image))
Esempio n. 2
0
def run_keypoint_detection(config):
    """Run the live keypoint-detection camera demo until ESC is pressed.

    Inference runs asynchronously in a single worker process
    (``_run_inference``) while the main process keeps reading camera frames.
    Frames captured while one inference is in flight are queued, and queued
    frames are displayed overlaid with the most recent finished result and
    its FPS.

    Args:
        config: object whose ``IMAGE_SIZE`` is ``(height, width)`` of the
            network input.

    Raises:
        RuntimeError: if a frame cannot be read from the camera.
    """
    camera_width = 320
    camera_height = 240
    window_name = "Keypoint Detection Demo"

    input_width = config.IMAGE_SIZE[1]
    input_height = config.IMAGE_SIZE[0]

    vc = init_camera(camera_width, camera_height)

    # Using the pool as a context manager terminates the worker process on
    # every exit path (ESC or an exception) so it is never leaked.
    with Pool(processes=1, initializer=nn.init) as pool:
        result = None
        fps = 1.0

        q_save = Queue()
        q_show = Queue()

        grabbed, camera_img = vc.read()
        if not grabbed or camera_img is None:
            raise RuntimeError("Failed to read a frame from the camera.")

        q_show.put(camera_img.copy())
        input_img = camera_img.copy()

        while True:
            m1 = MyTime("1 loop of while(1) of main()")
            # Start inference on the frame selected during the previous round.
            pool_result = pool.apply_async(_run_inference, (input_img, ))
            is_first = True
            while True:
                grabbed, camera_img = vc.read()
                if not grabbed or camera_img is None:
                    raise RuntimeError(
                        "Failed to read a frame from the camera.")
                if is_first:
                    # The first frame of this round is the next inference input.
                    input_img = camera_img.copy()
                    is_first = False
                q_save.put(camera_img.copy())
                if not q_show.empty():
                    window_img = q_show.get()
                    drawed_img = window_img
                    if result is not None:

                        drawed_img = visualize_keypoint_detection(
                            window_img, result[0], (input_height, input_width))
                        drawed_img = add_fps(drawed_img, fps)

                    cv2.imshow(window_name, drawed_img)
                    key = cv2.waitKey(2)  # Wait for 2ms
                    # TODO(yang): Consider using another key for abort.
                    if key == 27:  # ESC to quit
                        return

                # TODO(yang): Busy loop is not efficient here. Improve it and change them in other tasks.
                if pool_result.ready():
                    break

            # NOTE(review): swap_queue presumably exchanges the two queues so
            # frames saved during this round are shown next round — confirm.
            q_show = clear_queue(q_show)
            q_save, q_show = swap_queue(q_save, q_show)
            result, fps = pool_result.get()
            m1.show()
Esempio n. 3
0
def show_keypoint_detection(img, result, fps, window_height, window_width,
                            config):
    """Draw keypoints and FPS onto ``img`` and show it in the demo window.

    Args:
        img: camera frame to display.
        result: inference output; ``result[0]`` is passed to
            ``visualize_keypoint_detection`` as the joints.
        fps: value overlaid on the frame by ``add_fps``.
        window_height: height the frame is resized to before drawing.
        window_width: width the frame is resized to before drawing.
        config: object whose ``IMAGE_SIZE`` is ``(height, width)`` of the
            network input; that size is passed to
            ``visualize_keypoint_detection`` together with the joints.
    """
    # Network input size, read from config as (height, width). These names
    # were previously referenced here without being defined.
    input_height = config.IMAGE_SIZE[0]
    input_width = config.IMAGE_SIZE[1]

    window_img = resize(img, size=[window_height, window_width])

    window_img = visualize_keypoint_detection(window_img, result[0],
                                              (input_height, input_width))
    window_img = add_fps(window_img, fps)

    window_name = "Keypoint Detection Demo"
    cv2.imshow(window_name, window_img)
Esempio n. 4
0
    def py_visualize_output(images, heatmaps, stride=2):
        """Draw joints decoded from heatmaps onto a batch of images.

        Intended mainly for inspecting pose-estimation output during training.

        Args:
            images: numpy array of shape (batch_size, height, width, 3); it is
                scaled by 255 and cast to uint8 (presumably values lie in
                [0, 1] — confirm with callers).
            heatmaps: numpy array of shape (batch_size, height, width, num_joints).
            stride: int, ratio image_height / heatmap_height, forwarded to
                ``gaussian_heatmap_to_joints``.

        Returns:
            numpy uint8 array of shape (batch_size, height, width, 3) holding
            the annotated images.
        """
        canvas = np.uint8(images * 255.0)
        batch_size = images.shape[0]

        for idx in range(batch_size):
            keypoints = gaussian_heatmap_to_joints(heatmaps[idx], stride=stride)
            canvas[idx] = visualize_keypoint_detection(canvas[idx], keypoints)

        return canvas
Esempio n. 5
0
    def _keypoint_detection(self, result_json, raw_images, image_files):
        outputs = json.loads(result_json)
        results = outputs["results"]
        filename_images = []

        for i, (result, raw_image,
                image_file) in enumerate(zip(results, raw_images,
                                             image_files)):
            base, _ = os.path.splitext(os.path.basename(image_file))
            file_name = "{}.png".format(base)
            joints_list = result["prediction"]["joints"]
            number_joints = len(joints_list) // 3
            joints = np.zeros(shape=(number_joints, 3), dtype=np.float)
            for j in range(number_joints):
                joints[j, 0] = joints_list[j * 3]
                joints[j, 1] = joints_list[j * 3 + 1]
                joints[j, 2] = joints_list[j * 3 + 2]
            image = visualize_keypoint_detection(raw_image, joints)
            filename_images.append((file_name, PIL.Image.fromarray(image)))

        return filename_images