def test_drawing_spec(self):
   """Checks that draw_landmarks applies caller-supplied DrawingSpecs.

   Two normalized landmarks at (0.1, 0.1) and (0.8, 0.8) are drawn on a
   100x100 black image together with the connection between them. The
   reference image is produced directly with OpenCV at pixels (10, 10) and
   (80, 80), using the same colors, thicknesses and circle radius, and the
   two images must match exactly.
   """
   canvas = np.zeros((100, 100, 3), np.uint8)
   landmarks = text_format.Parse(
       'landmark {x: 0.1 y: 0.1}'
       'landmark {x: 0.8 y: 0.8}', landmark_pb2.NormalizedLandmarkList())
   point_spec = drawing_utils.DrawingSpec(color=(0, 0, 255), thickness=5)
   line_spec = drawing_utils.DrawingSpec(color=(255, 0, 0), thickness=3)

   # Build the reference image: the connection first, then a circle on each
   # endpoint, so the circles are painted over the line ends.
   reference = np.copy(canvas)
   endpoints = ((10, 10), (80, 80))
   cv2.line(reference, endpoints[0], endpoints[1], line_spec.color,
            line_spec.thickness)
   for center in endpoints:
      cv2.circle(reference, center, point_spec.circle_radius,
                 point_spec.color, point_spec.thickness)

   drawing_utils.draw_landmarks(
       image=canvas,
       landmark_list=landmarks,
       connections=[(0, 1)],
       landmark_drawing_spec=point_spec,
       connection_drawing_spec=line_spec)
   np.testing.assert_array_equal(canvas, reference)
예제 #2
0
 def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int):
     """Draws any detected face landmarks on `frame` and saves it as a PNG.

     The image is written to the system temp directory under the name
     `<last dotted component of self.id()>_frame_<idx>.png`.

     Args:
       frame: Image to annotate; the landmarks are drawn onto it.
       results: Solution output exposing a `multi_face_landmarks` attribute.
       idx: Frame index, used only to build the output file name.
     """
     drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1)
     # MediaPipe reports "no face detected" as None rather than an empty list.
     # Treat that as nothing to draw so this debug helper still dumps the frame
     # instead of failing with a TypeError.
     for face_landmarks in results.multi_face_landmarks or ():
         mp_drawing.draw_landmarks(image=frame,
                                   landmark_list=face_landmarks,
                                   landmark_drawing_spec=drawing_spec)
     path = os.path.join(
         tempfile.gettempdir(),
         self.id().split('.')[-1] + '_frame_{}.png'.format(idx))
     cv2.imwrite(path, frame)
 def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int):
     """Draws face, hand and pose landmarks on `frame` and saves it as a PNG.

     The face landmarks are drawn with a thin custom spec and no connections;
     both hands and the pose are drawn with their connection sets and the
     default styling of `draw_landmarks`. The image is written to the system
     temp directory as `<last dotted component of self.id()>_frame_<idx>.png`.

     Args:
       frame: Image to annotate; the landmarks are drawn onto it.
       results: Solution output exposing `face_landmarks`,
         `left_hand_landmarks`, `right_hand_landmarks` and `pose_landmarks`.
       idx: Frame index, used only to build the output file name.
     """
     face_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1)
     mp_drawing.draw_landmarks(
         image=frame,
         landmark_list=results.face_landmarks,
         landmark_drawing_spec=face_spec)

     # Left hand, right hand, then pose -- each with its connection set.
     connected_parts = (
         (results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS),
         (results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS),
         (results.pose_landmarks, mp_holistic.POSE_CONNECTIONS),
     )
     for part_landmarks, part_connections in connected_parts:
         mp_drawing.draw_landmarks(frame, part_landmarks, part_connections)

     test_name = self.id().split('.')[-1]
     file_name = test_name + '_frame_{}.png'.format(idx)
     cv2.imwrite(os.path.join(tempfile.gettempdir(), file_name), frame)
# limitations under the License.

"""Tests for mediapipe.python.solutions.drawing_utils."""

from absl.testing import absltest
from absl.testing import parameterized
import cv2
import numpy as np

from google.protobuf import text_format

from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python.solutions import drawing_utils

# Module-level DrawingSpec instances. All use the DrawingSpec defaults except
# the circle spec, which overrides the color with (0, 0, 255).
# NOTE(review): presumably these mirror the defaults applied by drawing_utils
# and are used by tests to build expected images -- no use is visible in this
# excerpt, confirm against the rest of the file.
DEFAULT_BBOX_DRAWING_SPEC = drawing_utils.DrawingSpec()
DEFAULT_CONNECTION_DRAWING_SPEC = drawing_utils.DrawingSpec()
DEFAULT_CIRCLE_DRAWING_SPEC = drawing_utils.DrawingSpec(color=(0, 0, 255))
DEFAULT_AXIS_DRAWING_SPEC = drawing_utils.DrawingSpec()


class DrawingUtilTest(parameterized.TestCase):
  """Tests for the drawing helpers in drawing_utils."""

  def test_invalid_input_image(self):
    """draw_landmarks and draw_detection reject a two-channel image."""
    expected_error = 'Input image must contain three channel rgb data.'
    # 3x3 image with two channels instead of the required three.
    two_channel_image = np.arange(18, dtype=np.uint8).reshape(3, 3, 2)

    with self.assertRaisesRegex(ValueError, expected_error):
      drawing_utils.draw_landmarks(two_channel_image,
                                   landmark_pb2.NormalizedLandmarkList())

    with self.assertRaisesRegex(ValueError, expected_error):
      drawing_utils.draw_detection(two_channel_image,
                                   detection_pb2.Detection())
예제 #5
0
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mediapipe.python.solutions.drawing_utils."""

from absl.testing import absltest
from absl.testing import parameterized
import cv2
import numpy as np

from google.protobuf import text_format

from mediapipe.framework.formats import landmark_pb2
from mediapipe.python.solutions import drawing_utils

# Module-level DrawingSpec instances: the connection spec uses the DrawingSpec
# defaults, the landmark spec overrides the color with (0, 0, 255).
# NOTE(review): presumably these mirror the defaults applied by drawing_utils
# and are used by tests to build expected images -- no use is visible in this
# excerpt, confirm against the rest of the file.
DEFAULT_CONNECTION_DRAWING_SPEC = drawing_utils.DrawingSpec()
DEFAULT_LANDMARK_DRAWING_SPEC = drawing_utils.DrawingSpec(color=(0, 0, 255))


class DrawingUtilTest(parameterized.TestCase):

  def test_invalid_input_image(self):
    """draw_landmarks rejects an image that does not have three channels."""
    expected_error = 'Input image must contain three channel rgb data.'
    # 3x3 image with two channels instead of the required three.
    two_channel_image = np.arange(18, dtype=np.uint8).reshape(3, 3, 2)
    empty_landmarks = landmark_pb2.NormalizedLandmarkList()
    with self.assertRaisesRegex(ValueError, expected_error):
      drawing_utils.draw_landmarks(two_channel_image, empty_landmarks)

  def test_invalid_connection(self):
    landmark_list = text_format.Parse(
        'landmark {x: 0.5 y: 0.5} landmark {x: 0.2 y: 0.2}',
        landmark_pb2.NormalizedLandmarkList())