Ejemplo n.º 1
0
def test_draw_keypoints_vanilla():
    """Compare a single-color keypoint drawing against a stored reference image.

    One connection ``(0, 1)`` is drawn in red on a black 100x100 RGB canvas.
    If the reference PNG does not exist yet, it is written from the current
    output before the comparison. Finally the test verifies that neither the
    keypoints nor the canvas were modified in place.
    """
    # The keypoints tensor is a module-level global defined elsewhere in the file.
    original_keypoints = keypoints.clone()

    reference_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "assets",
        "fakedata",
        "draw_keypoint_vanilla.png",
    )

    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()
    drawn = utils.draw_keypoints(canvas, keypoints, colors="red", connectivity=[(0, 1)])

    # NOTE(review): a missing reference is generated from the current output,
    # so the comparison below passes trivially on that first run.
    if not os.path.exists(reference_path):
        as_hwc = drawn.permute(1, 2, 0).contiguous().numpy()
        Image.fromarray(as_hwc).save(reference_path)

    # np.array copies the pixel data, so the file can be closed right away.
    with Image.open(reference_path) as reference_img:
        reference = torch.as_tensor(np.array(reference_img)).permute(2, 0, 1)

    assert_equal(drawn, reference)
    # draw_keypoints must not mutate its keypoints argument ...
    assert_equal(keypoints, original_keypoints)
    # ... nor the image it draws on.
    assert_equal(canvas, canvas_before)
Ejemplo n.º 2
0
def test_draw_keypoints_errors():
    """Check that ``draw_keypoints`` rejects malformed images and keypoints.

    Each case passes one invalid argument. It checks both the exception type
    and a fragment of the error message (``match`` is a regex search).
    """
    h, w = 10, 10
    valid_img = torch.zeros((3, 100, 100), dtype=torch.uint8)

    # Anything that is not a tensor is a type error.
    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)

    # Only uint8 images are accepted.
    int64_img = torch.zeros((3, h, w), dtype=torch.int64)
    with pytest.raises(ValueError, match="The image dtype must be"):
        utils.draw_keypoints(image=int64_img, keypoints=keypoints)

    # A 4-D batch of images is refused.
    image_batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
    with pytest.raises(ValueError,
                       match="Pass individual images, not batches"):
        utils.draw_keypoints(image=image_batch, keypoints=keypoints)

    # Single-channel images are refused.
    gray_img = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        utils.draw_keypoints(image=gray_img, keypoints=keypoints)

    # 2-D keypoints (no per-instance dimension) are refused.
    flat_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]],
                                  dtype=torch.float)
    with pytest.raises(ValueError, match="keypoints must be of shape"):
        utils.draw_keypoints(image=valid_img, keypoints=flat_keypoints)
Ejemplo n.º 3
0
def test_draw_keypoints_colored(colors):
    """Draw keypoints with the given ``colors`` value and validate the output.

    Checks that the result has the same (C, H, W) shape as the input and is
    uint8. Also checks that neither the keypoints nor the input image are
    modified in place.

    Args:
        colors: color specification forwarded to ``utils.draw_keypoints``
            (presumably supplied by a pytest parametrize decorator).
    """
    # Keypoints is declared on top as global variable
    keypoints_cp = keypoints.clone()

    img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    img_cp = img.clone()
    result = utils.draw_keypoints(
        img,
        keypoints,
        colors=colors,
        connectivity=[
            (0, 1),
        ],
    )
    # Checking only the channel count would let a resized or re-typed output
    # pass; the drawn image must keep the input's full shape and uint8 dtype.
    assert result.shape == img.shape
    assert result.dtype == torch.uint8
    # Check that keypoints are not modified inplace
    assert_equal(keypoints, keypoints_cp)
    # Check that image is not modified in place
    assert_equal(img, img_cp)
Ejemplo n.º 4
0
# Only detections whose score is strictly above this threshold are kept.
detect_threshold = 0.75
# torch.where with a single condition returns a tuple of index tensors,
# one per dimension of ``scores``.
idx = torch.where(scores > detect_threshold)
# NOTE(review): presumably ``kpts`` holds one keypoint set per detection,
# aligned with ``scores`` - confirm against where both are defined above.
keypoints = kpts[idx]

print(keypoints)

#####################################
# Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by x, y coordinates and the visibility.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
# Note that the utility expects uint8 images.

from torchvision.utils import draw_keypoints

# Draw the selected keypoints in blue; ``radius`` controls the size of each
# drawn keypoint marker.
# NOTE(review): ``person_int`` is assumed to be the uint8 image tensor and
# ``show`` a display helper, both defined earlier in the script.
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)

#####################################
# As we see the keypoints appear as colored circles over the image.
# The COCO keypoints for a person are ordered as in the following list.

# The 17 COCO person keypoint names, in index order.
coco_keypoints = [
    # head
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    # arms
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    # legs
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
]

#####################################
# What if we are interested in joining the keypoints?