示例#1
0
        def graph_fn():
            """Runs `get_valid_keypoint_mask_for_class` on a fixed 3-row fixture.

            Returns:
                The `(mask, keypoints_nan_to_zeros)` pair unpacked from the
                helper's result for class_id=2 and keypoint_indices=[1, 2].
            """
            nan = float('nan')
            # Rows 0 and 2 have column 2 set, matching class_id=2 below.
            onehot = tf.constant(
                [
                    [0, 0, 1, 0, 0],
                    [0, 1, 0, 0, 0],
                    [0, 0, 1, 0, 1],
                ],
                dtype=tf.float32,
            )
            per_axis = tf.constant(
                [
                    [0.1, nan, 0.2, 0.0],
                    [0.0, 0.0, 0.1, 0.9],
                    [3.2, 4.3, nan, 0.2],
                ],
                dtype=tf.float32,
            )
            # Reuse the same [3, 4] values for both entries of a new trailing
            # axis, producing a [3, 4, 2] coordinate tensor.
            coords = tf.stack([per_axis, per_axis], axis=2)
            outputs = ta_utils.get_valid_keypoint_mask_for_class(
                keypoint_coordinates=coords, class_id=2, class_onehot=onehot,
                keypoint_indices=[1, 2])
            mask, keypoints_nan_to_zeros = outputs
            return mask, keypoints_nan_to_zeros
示例#2
0
 def test_get_valid_keypoints_mask(self):
   """Checks the keypoint mask and NaN-zeroed output for class 2.

   The expected arrays below encode that only rows whose one-hot vector has
   column 2 set can be unmasked, that keypoint indices 1 and 2 are selected,
   and that a NaN at a selected keypoint is masked out and returned as 0.0.
   """
   nan = float('nan')
   class_onehot = tf.constant(
       [[0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 1]],
       dtype=tf.float32)
   keypoint_rows = [
       [0.1, nan, 0.2, 0.0],
       [0.0, 0.0, 0.1, 0.9],
       [3.2, 4.3, nan, 0.2],
   ]
   keypoints = tf.constant(keypoint_rows, dtype=tf.float32)
   # NOTE(review): this call uses the keyword `keypoint_coordinate` (singular)
   # and a 2-D [3, 4] tensor; confirm both against the helper's current
   # signature, which may expect `keypoint_coordinates` with [N, K, 2] input.
   result = ta_utils.get_valid_keypoint_mask_for_class(
       keypoint_coordinate=keypoints,
       class_id=2,
       class_onehot=class_onehot,
       keypoint_indices=[1, 2])
   mask, keypoints_nan_to_zeros = result
   want_mask = np.array([[0, 1], [0, 0], [1, 0]])
   # Kept as float32 so exact equality holds for values such as 0.2.
   want_keypoints = tf.constant([[0.0, 0.2], [0.0, 0.1], [4.3, 0.0]],
                                dtype=tf.float32)
   checks = ((mask, want_mask), (keypoints_nan_to_zeros, want_keypoints))
   for actual, expected in checks:
     np.testing.assert_array_equal(actual.numpy(), expected)