def test_draw_keypoints_vanilla():
    """Draw red keypoints joined by a single edge and compare with a golden image.

    Uses the module-level ``keypoints`` tensor. When the reference PNG is
    missing, the current output is written to disk first and then compared
    against itself. Also checks that neither the keypoints nor the input image
    are modified in place.
    """
    # `keypoints` is the module-level tensor shared by the keypoint tests.
    keypoints_before = keypoints.clone()
    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    canvas_before = canvas.clone()

    drawn = utils.draw_keypoints(canvas, keypoints, colors="red", connectivity=[(0, 1)])

    test_dir = os.path.dirname(os.path.abspath(__file__))
    golden_path = os.path.join(test_dir, "assets", "fakedata", "draw_keypoint_vanilla.png")
    if not os.path.exists(golden_path):
        # First run: record the current output as the reference.
        # The result is channel-first; PIL wants channel-last.
        channels_last = drawn.permute(1, 2, 0).contiguous().numpy()
        Image.fromarray(channels_last).save(golden_path)

    golden = torch.as_tensor(np.array(Image.open(golden_path))).permute(2, 0, 1)
    assert_equal(drawn, golden)

    # Drawing must leave both inputs untouched.
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)
def test_draw_keypoints_errors():
    """Check that draw_keypoints rejects malformed inputs with clear errors.

    The cases covered are:
    - a non-tensor image (TypeError);
    - an int64 image;
    - a batched 4-D image;
    - a single-channel image;
    - a 2-D keypoints tensor.
    All cases after the first raise ValueError.
    """
    height, width = 10, 10
    valid_img = torch.full((3, 100, 100), 0, dtype=torch.uint8)

    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)

    int64_img = torch.full((3, height, width), 0, dtype=torch.int64)
    with pytest.raises(ValueError, match="The image dtype must be"):
        utils.draw_keypoints(image=int64_img, keypoints=keypoints)

    batched_imgs = torch.randint(0, 256, size=(10, 3, height, width), dtype=torch.uint8)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        utils.draw_keypoints(image=batched_imgs, keypoints=keypoints)

    gray_img = torch.randint(0, 256, size=(1, height, width), dtype=torch.uint8)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        utils.draw_keypoints(image=gray_img, keypoints=keypoints)

    flat_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
    with pytest.raises(ValueError, match="keypoints must be of shape"):
        utils.draw_keypoints(image=valid_img, keypoints=flat_keypoints)
def test_draw_keypoints_colored(colors):
    """Smoke-test draw_keypoints with the given ``colors`` argument.

    Checks that the output has 3 channels. Also checks that neither the
    module-level ``keypoints`` tensor nor the input image is modified in place.
    """
    # `keypoints` is the module-level tensor shared by the keypoint tests.
    kpts_before = keypoints.clone()
    blank = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    blank_before = blank.clone()

    out = utils.draw_keypoints(blank, keypoints, colors=colors, connectivity=[(0, 1)])
    assert out.shape[0] == 3

    assert_equal(keypoints, kpts_before)
    assert_equal(blank, blank_before)
detect_threshold = 0.75
# Keep only the detections whose score exceeds the threshold.
# torch.where with a single argument returns a tuple of index tensors.
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
print(keypoints)

#####################################
# Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by x, y coordinates and the visibility.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
# Note that the utility expects uint8 images.

from torchvision.utils import draw_keypoints

res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)

#####################################
# As we see, the keypoints appear as colored circles over the image.
# The COCO keypoints for a person are ordered and represent the following list.

coco_keypoints = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]

#####################################
# What if we are interested in joining the keypoints?