def graph_fn():
    """Build fixed inputs and run the valid-keypoint mask helper on them.

    Constructs a 3x5 float class-indicator tensor and a 3x4 float keypoint
    tensor containing two NaN entries, duplicates the keypoint tensor along
    a new trailing axis (giving shape [3, 4, 2]), and passes both to
    `ta_utils.get_valid_keypoint_mask_for_class` with `class_id=2` and
    `keypoint_indices=[1, 2]`.

    Returns:
      The `(mask, keypoints_nan_to_zeros)` pair produced by the helper.
    """
    nan = float('nan')

    # One row per instance, one column per class.
    instance_classes = tf.constant(
        [
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 1],
        ],
        dtype=tf.float32)

    # One row per instance, one column per keypoint; NaN marks entries the
    # helper is expected to treat specially.
    per_axis_values = tf.constant(
        [
            [0.1, nan, 0.2, 0.0],
            [0.0, 0.0, 0.1, 0.9],
            [3.2, 4.3, nan, 0.2],
        ],
        dtype=tf.float32)

    # Two identical copies stacked on axis 2 -> shape [3, 4, 2].
    stacked_coordinates = tf.stack([per_axis_values, per_axis_values], axis=2)

    valid_mask, nan_free_keypoints = (
        ta_utils.get_valid_keypoint_mask_for_class(
            keypoint_coordinates=stacked_coordinates,
            class_id=2,
            class_onehot=instance_classes,
            keypoint_indices=[1, 2]))
    return valid_mask, nan_free_keypoints
def test_get_valid_keypoints_mask(self):
    """Check the mask and NaN-cleaned keypoints for a single class id.

    Feeds `ta_utils.get_valid_keypoint_mask_for_class` three instances with
    four keypoints each, asks for `class_id=2` restricted to keypoint
    columns 1 and 2, and compares both returned tensors against hand-written
    expectations.

    NOTE(review): the helper is called with a [3, 4, 2] tensor (two stacked
    copies of the 2-D values) under the keyword `keypoint_coordinates`,
    matching the other call to this helper in this file. This assumes the
    helper expects a trailing axis of size 2 -- confirm against its
    signature.
    """
    nan = float('nan')

    class_onehot = tf.constant(
        [
            [0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 1],
        ],
        dtype=tf.float32)
    keypoints = tf.constant(
        [
            [0.1, nan, 0.2, 0.0],
            [0.0, 0.0, 0.1, 0.9],
            [3.2, 4.3, nan, 0.2],
        ],
        dtype=tf.float32)
    # Duplicate the per-keypoint values on a new last axis -> [3, 4, 2].
    keypoint_coordinates = tf.stack([keypoints, keypoints], axis=2)

    # The keyword was previously misspelled as `keypoint_coordinate`, which
    # would raise TypeError before any assertion could run.
    mask, keypoints_nan_to_zeros = (
        ta_utils.get_valid_keypoint_mask_for_class(
            keypoint_coordinates=keypoint_coordinates,
            class_id=2,
            class_onehot=class_onehot,
            keypoint_indices=[1, 2]))

    # Expected mask for keypoint columns 1 and 2 of each instance.
    expected_mask = np.array([[0, 1], [0, 0], [1, 0]])
    # Columns 1 and 2 of `keypoints` with NaN replaced by 0. float32 keeps
    # the exact-equality check below free of float32-vs-float64 rounding
    # differences (e.g. for 0.2 and 4.3).
    expected_values = np.array(
        [[0.0, 0.2], [0.0, 0.1], [4.3, 0.0]], dtype=np.float32)
    # Stacked the same way as the input so the shapes line up: [3, 2, 2].
    expected_keypoints = np.stack([expected_values, expected_values], axis=2)

    np.testing.assert_array_equal(mask.numpy(), expected_mask)
    np.testing.assert_array_equal(
        keypoints_nan_to_zeros.numpy(), expected_keypoints)