def test_read_batch_from_dataset_tables(self):
        """Reads the same test table twice and checks the batched shapes."""
        # Test data location; assumes $PWD == "google_research/".
        table_path = os.path.join(
            FLAGS.test_srcdir, 'poem/testdata', 'tfe-2.tfrecords')
        profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
            'LEGACY_3DH36M17')
        profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
            'LEGACY_2DCOCO13')

        inputs = pipeline_utils.read_batch_from_dataset_tables(
            [table_path] * 2,
            batch_sizes=[4, 2],
            num_instances_per_record=2,
            shuffle=True,
            num_epochs=None,
            keypoint_names_3d=profile_3d.keypoint_names,
            keypoint_names_2d=profile_2d.keypoint_names,
            seed=0)

        # The leading dimension 6 presumably equals sum(batch_sizes) = 4 + 2.
        expected_shapes = {
            'image_sizes': [6, 2, 2],
            'keypoints_2d': [6, 2, 13, 2],
            'keypoint_scores_2d': [6, 2, 13],
            'keypoint_masks_2d': [6, 2, 13],
            'keypoints_3d': [6, 2, 17, 3],
        }
        self.assertCountEqual(inputs.keys(), list(expected_shapes))
        for key, shape in expected_shapes.items():
            self.assertEqual(inputs[key].shape, shape)
 def test_random_project_and_select_keypoints(self):
     """Projects a synthetic 17-keypoint pose to 2D and checks the result."""
     # Keypoint i is placed at (i, i, i) for i in [0, 17).
     keypoints_3d = tf.constant([[float(i)] * 3 for i in range(17)])
     profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_3DH36M17')
     profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_2DCOCO13')
     output_names = profile_2d.compatible_keypoint_name_dict[
         'LEGACY_3DH36M17']

     # Every angle range has equal endpoints, presumably pinning the camera to
     # a single view so that the projection is deterministic.
     projected_2d, _ = keypoint_utils.random_project_and_select_keypoints(
         keypoints_3d,
         keypoint_profile_3d=profile_3d,
         output_keypoint_names=output_names,
         azimuth_range=(math.pi / 2.0, math.pi / 2.0),
         elevation_range=(math.pi / 2.0, math.pi / 2.0),
         roll_range=(-math.pi / 2.0, -math.pi / 2.0))
     normalized_2d, _, _ = profile_2d.normalize(projected_2d)

     expected_2d = [
         [-0.4356161, 0.4356161],
         [-0.32822642, 0.32822642],
         [-0.2897728, 0.28977284],
         [-0.24986516, 0.24986516],
         [-0.2084193, 0.2084193],
         [-0.16534455, 0.16534461],
         [-0.12054307, 0.1205431],
         [-0.025327, 0.025327],
         [0.025327, -0.025327],
         [0.07818867, -0.07818867],
         [0.13340548, -0.13340548],
         [0.19113854, -0.19113848],
         [0.2515637, -0.25156358],
     ]
     self.assertAllClose(normalized_2d, expected_2d)
 def test_randomly_project_and_select_keypoints(self):
   """Projects a 3DSTD16 pose onto the 2DSTD13 keypoints with a fixed camera."""
   keypoints_3d = tf.constant([
       [2.0, 1.0, 3.0],  # HEAD.
       [2.01, 1.01, 3.01],  # NECK.
       [2.02, 1.02, 3.02],  # LEFT_SHOULDER.
       [2.03, 1.03, 3.03],  # RIGHT_SHOULDER.
       [2.04, 1.04, 3.04],  # LEFT_ELBOW.
       [2.05, 1.05, 3.05],  # RIGHT_ELBOW.
       [2.06, 1.06, 3.06],  # LEFT_WRIST.
       [2.07, 1.07, 3.07],  # RIGHT_WRIST.
       [2.08, 1.08, 3.08],  # SPINE.
       [2.09, 1.09, 3.09],  # PELVIS.
       [2.10, 1.10, 3.10],  # LEFT_HIP.
       [2.11, 1.11, 3.11],  # RIGHT_HIP.
       [2.12, 1.12, 3.12],  # LEFT_KNEE.
       [2.13, 1.13, 3.13],  # RIGHT_KNEE.
       [2.14, 1.14, 3.14],  # LEFT_ANKLE.
       [2.15, 1.15, 3.15],  # RIGHT_ANKLE.
   ])
   source_profile = keypoint_profiles.create_keypoint_profile_or_die('3DSTD16')
   target_profile = keypoint_profiles.create_keypoint_profile_or_die('2DSTD13')
   target_names = target_profile.compatible_keypoint_name_dict['3DSTD16']

   # Each angle and depth range is a single point, so the camera is fixed.
   projected, _ = keypoint_utils.randomly_project_and_select_keypoints(
       keypoints_3d,
       keypoint_profile_3d=source_profile,
       output_keypoint_names=target_names,
       azimuth_range=(math.pi / 2.0, math.pi / 2.0),
       elevation_range=(-math.pi / 2.0, -math.pi / 2.0),
       roll_range=(math.pi, math.pi),
       normalized_camera_depth_range=(2.0, 2.0),
       normalize_before_projection=False)

   # NECK, SPINE and PELVIS have no 2DSTD13 counterpart and are dropped.
   expected = [
       [-1.0 / 4.0, -3.0 / 4.0],  # NOSE_TIP
       [-1.02 / 4.02, -3.02 / 4.02],  # LEFT_SHOULDER.
       [-1.03 / 4.03, -3.03 / 4.03],  # RIGHT_SHOULDER.
       [-1.04 / 4.04, -3.04 / 4.04],  # LEFT_ELBOW.
       [-1.05 / 4.05, -3.05 / 4.05],  # RIGHT_ELBOW.
       [-1.06 / 4.06, -3.06 / 4.06],  # LEFT_WRIST.
       [-1.07 / 4.07, -3.07 / 4.07],  # RIGHT_WRIST.
       [-1.10 / 4.10, -3.10 / 4.10],  # LEFT_HIP.
       [-1.11 / 4.11, -3.11 / 4.11],  # RIGHT_HIP.
       [-1.12 / 4.12, -3.12 / 4.12],  # LEFT_KNEE.
       [-1.13 / 4.13, -3.13 / 4.13],  # RIGHT_KNEE.
       [-1.14 / 4.14, -3.14 / 4.14],  # LEFT_ANKLE.
       [-1.15 / 4.15, -3.15 / 4.15],  # RIGHT_ANKLE.
   ]
   self.assertAllClose(projected, expected)
 def test_select_keypoints_by_name(self):
     """Selects the LEGACY_2DCOCO13-compatible subset of LEGACY_3DH36M17."""
     # Row i of the input is (i, i, i), so each output row identifies the
     # input index it was taken from.
     input_keypoints = tf.constant([[float(i)] * 3 for i in range(17)])
     profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_3DH36M17')
     profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_2DCOCO13')

     selected, _ = keypoint_utils.select_keypoints_by_name(
         input_keypoints,
         input_keypoint_names=profile_3d.keypoint_names,
         output_keypoint_names=profile_2d.compatible_keypoint_name_dict[
             'LEGACY_3DH36M17'])

     # Indices 0 (Hip), 2 (Neck/Nose), 3 (Thorax) and 10 (Spine) are absent.
     kept_indices = (1, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16)
     self.assertAllClose(selected, [[float(i)] * 3 for i in kept_indices])
Beispiel #5
0
 def test_legacy_h36m13_keypoint_profile_3d_is_correct(self):
   """Checks the attributes derived for the LEGACY_3DH36M13 3D profile."""
   profile = keypoint_profiles.create_keypoint_profile_or_die(
       'LEGACY_3DH36M13')
   side = keypoint_profiles.LeftRightType

   # Basic metadata.
   self.assertEqual(profile.name, 'LEGACY_3DH36M13')
   self.assertEqual(profile.keypoint_dim, 3)
   self.assertEqual(profile.keypoint_num, 13)
   self.assertEqual(profile.keypoint_names, [
       'Head', 'LShoulder', 'RShoulder', 'LElbow', 'RElbow', 'LWrist',
       'RWrist', 'LHip', 'RHip', 'LKnee', 'RKnee', 'LFoot', 'RFoot'
   ])

   # Left/right classification of one keypoint and one segment.
   self.assertEqual(profile.keypoint_left_right_type(1), side.LEFT)
   self.assertEqual(profile.segment_left_right_type(1, 4), side.CENTRAL)

   # Offset and scale keypoint indices.
   self.assertEqual(profile.offset_keypoint_index, [7])
   self.assertEqual(profile.scale_keypoint_index_pairs, [([7, 8], [1, 2])])

   # Name lookup; an unknown name maps to -1.
   self.assertEqual(profile.keypoint_index('LShoulder'), 1)
   self.assertEqual(profile.keypoint_index('dummy'), -1)

   # Skeleton connectivity.
   self.assertEqual(profile.segment_index_pairs, [
       ([7, 8], [1, 2]), ([7, 8], [7]), ([7, 8], [8]), ([7], [9]),
       ([8], [10]), ([9], [11]), ([10], [12]), ([1, 2], [0]),
       ([1, 2], [1]), ([1, 2], [2]), ([1], [3]), ([2], [4]), ([3], [5]),
       ([4], [6])
   ])
   self.assertAllEqual(profile.keypoint_affinity_matrix, [
       [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0],
       [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
       [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
   ])

   # Body-part index properties, checked as `<part>_keypoint_index`.
   expected_part_indices = [
       ('head', [0]),
       ('neck', [1, 2]),
       ('left_shoulder', [1]),
       ('right_shoulder', [2]),
       ('left_elbow', [3]),
       ('right_elbow', [4]),
       ('left_wrist', [5]),
       ('right_wrist', [6]),
       ('spine', [1, 2, 7, 8]),
       ('pelvis', [7, 8]),
       ('left_hip', [7]),
       ('right_hip', [8]),
       ('left_knee', [9]),
       ('right_knee', [10]),
       ('left_ankle', [11]),
       ('right_ankle', [12]),
   ]
   for part, indices in expected_part_indices:
     self.assertEqual(getattr(profile, part + '_keypoint_index'), indices)
 def test_std13_keypoint_profile_3d_is_correct(self):
     """Checks the attributes derived for the 3DSTD13 3D profile."""
     profile = keypoint_profiles.create_keypoint_profile_or_die('3DSTD13')
     side = keypoint_profiles.LeftRightType

     # Basic metadata.
     self.assertEqual(profile.name, '3DSTD13')
     self.assertEqual(profile.keypoint_dim, 3)
     self.assertEqual(profile.keypoint_num, 13)
     self.assertEqual(profile.keypoint_names, [
         'HEAD', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW',
         'RIGHT_ELBOW', 'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_HIP',
         'RIGHT_HIP', 'LEFT_KNEE', 'RIGHT_KNEE', 'LEFT_ANKLE', 'RIGHT_ANKLE'
     ])

     # Left/right classification of one keypoint and one segment.
     self.assertEqual(profile.keypoint_left_right_type(1), side.LEFT)
     self.assertEqual(profile.segment_left_right_type(1, 2), side.CENTRAL)

     # Offset and scale keypoint indices.
     self.assertEqual(profile.offset_keypoint_index, [7, 8])
     self.assertEqual(profile.scale_keypoint_index_pairs, [([1, 2], [7, 8])])

     # Name lookup; an unknown name maps to -1.
     self.assertEqual(profile.keypoint_index('LEFT_SHOULDER'), 1)
     self.assertEqual(profile.keypoint_index('dummy'), -1)

     # Skeleton connectivity.
     self.assertEqual(profile.segment_index_pairs, [
         ([0], [1, 2]), ([1, 2], [1]), ([1, 2], [2]),
         ([1, 2], [1, 2, 7, 8]), ([1], [3]), ([2], [4]), ([3], [5]),
         ([4], [6]), ([1, 2, 7, 8], [7, 8]), ([7, 8], [7]), ([7, 8], [8]),
         ([7], [9]), ([8], [10]), ([9], [11]), ([10], [12])
     ])
     self.assertAllEqual(profile.keypoint_affinity_matrix, [
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
     ])

     # Body-part index properties, checked as `<part>_keypoint_index`.
     expected_part_indices = [
         ('head', [0]),
         ('neck', [1, 2]),
         ('left_shoulder', [1]),
         ('right_shoulder', [2]),
         ('left_elbow', [3]),
         ('right_elbow', [4]),
         ('left_wrist', [5]),
         ('right_wrist', [6]),
         ('spine', [1, 2, 7, 8]),
         ('pelvis', [7, 8]),
         ('left_hip', [7]),
         ('right_hip', [8]),
         ('left_knee', [9]),
         ('right_knee', [10]),
         ('left_ankle', [11]),
         ('right_ankle', [12]),
     ]
     for part, indices in expected_part_indices:
         self.assertEqual(getattr(profile, part + '_keypoint_index'), indices)
Beispiel #7
0
def main(_):
    """Runs inference and saves the resulting embeddings as CSV files.

    Builds a graph that reads 2D keypoint inputs, converts them into model
    inputs, and runs the embedder. Weights are restored from
    `FLAGS.checkpoint_path`. Each available embedding output (means, stddevs,
    samples) is then written to `<FLAGS.output_dir>/<key>.csv`, with one row
    per entry along the output's leading dimension and all remaining
    dimensions flattened into columns.

    Args:
      _: Unused positional argument (command-line argv from the app runner).
    """
    keypoint_profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    g = tf.Graph()
    with g.as_default():
        keypoints_2d, keypoint_masks_2d = read_inputs(keypoint_profile_2d)

        model_inputs, _ = input_generator.create_model_input(
            keypoints_2d,
            keypoint_masks_2d=keypoint_masks_2d,
            keypoints_3d=None,
            model_input_keypoint_type=(
                common.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT),
            model_input_keypoint_mask_type=(
                FLAGS.model_input_keypoint_mask_type),
            keypoint_profile_2d=keypoint_profile_2d,
            # Fix seed for determinism.
            seed=1)

        embedder_fn = models.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=False,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm)

        outputs, _ = embedder_fn(model_inputs)

        if FLAGS.use_moving_average:
            # Restore through the variable mapping from pipeline_utils
            # (presumably the moving-average shadow variables).
            variables_to_restore = (
                pipeline_utils.get_moving_average_variables_to_restore())
            saver = tf.train.Saver(variables_to_restore)
        else:
            saver = tf.train.Saver()

        scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer(),
                                     saver=saver)
        session_creator = tf.train.ChiefSessionCreator(
            scaffold=scaffold,
            master=FLAGS.master,
            checkpoint_filename_with_path=FLAGS.checkpoint_path)

        with tf.train.MonitoredSession(session_creator=session_creator,
                                       hooks=None) as sess:
            outputs_result = sess.run(outputs)

    tf.gfile.MakeDirs(FLAGS.output_dir)
    for key in (common.KEY_EMBEDDING_MEANS, common.KEY_EMBEDDING_STDDEVS,
                common.KEY_EMBEDDING_SAMPLES):
        if key not in outputs_result:
            continue
        output = outputs_result[key]
        output_path = os.path.join(FLAGS.output_dir, key + '.csv')
        # Write through tf.gfile, matching the MakeDirs call above. A plain
        # np.savetxt(path) only handles local paths, so a non-local output_dir
        # would be created successfully and then fail on write.
        with tf.gfile.GFile(output_path, 'w') as output_file:
            np.savetxt(output_file,
                       output.reshape([output.shape[0], -1]),
                       delimiter=',')
 def test_legacy_h36m13_keypoint_profile_2d_is_correct(self):
     """Checks the attributes derived for the LEGACY_2DH36M13 2D profile."""
     profile = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_2DH36M13')
     side = keypoint_profiles.LeftRightType
     h36m_names = [
         'Head', 'LShoulder', 'RShoulder', 'LElbow', 'RElbow', 'LWrist',
         'RWrist', 'LHip', 'RHip', 'LKnee', 'RKnee', 'LFoot', 'RFoot'
     ]
     std_names = [
         'HEAD', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW',
         'RIGHT_ELBOW', 'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_HIP',
         'RIGHT_HIP', 'LEFT_KNEE', 'RIGHT_KNEE', 'LEFT_ANKLE',
         'RIGHT_ANKLE'
     ]

     # Basic metadata.
     self.assertEqual(profile.name, 'LEGACY_2DH36M13')
     self.assertEqual(profile.keypoint_dim, 2)
     self.assertEqual(profile.keypoint_num, 13)
     self.assertEqual(profile.keypoint_names, h36m_names)

     # Left/right classification of one keypoint and one segment.
     self.assertEqual(profile.keypoint_left_right_type(0), side.CENTRAL)
     self.assertEqual(profile.segment_left_right_type(0, 1), side.LEFT)

     # Offset and scale keypoint indices.
     self.assertEqual(profile.offset_keypoint_index, [7, 8])
     self.assertEqual(profile.scale_keypoint_index_pairs, [
         ([1], [2]), ([1], [7]), ([1], [8]), ([2], [7]), ([2], [8]),
         ([7], [8])
     ])

     # Name lookup; an unknown name maps to -1.
     self.assertEqual(profile.keypoint_index('LShoulder'), 1)
     self.assertEqual(profile.keypoint_index('dummy'), -1)

     # Skeleton connectivity.
     self.assertEqual(profile.segment_index_pairs, [
         ([0], [1]), ([0], [2]), ([1], [3]), ([3], [5]), ([2], [4]),
         ([4], [6]), ([1], [7]), ([2], [8]), ([7], [9]), ([9], [11]),
         ([8], [10]), ([10], [12]), ([1], [2]), ([7], [8])
     ])
     self.assertAllEqual(profile.keypoint_affinity_matrix, [
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
     ])

     # Names of the matching keypoints in each compatible 3D profile. The
     # MPII-3DHP names are the lower-cased standard names.
     self.assertEqual(
         profile.compatible_keypoint_name_dict, {
             '3DSTD16': std_names,
             '3DSTD13': std_names,
             'LEGACY_3DH36M17': h36m_names,
             'LEGACY_3DMPII3DHP17': [name.lower() for name in std_names],
         })

     # Body-part index properties, checked as `<part>_keypoint_index`.
     expected_part_indices = [
         ('head', [0]),
         ('neck', [1, 2]),
         ('left_shoulder', [1]),
         ('right_shoulder', [2]),
         ('left_elbow', [3]),
         ('right_elbow', [4]),
         ('left_wrist', [5]),
         ('right_wrist', [6]),
         ('spine', [1, 2, 7, 8]),
         ('pelvis', [7, 8]),
         ('left_hip', [7]),
         ('right_hip', [8]),
         ('left_knee', [9]),
         ('right_knee', [10]),
         ('left_ankle', [11]),
         ('right_ankle', [12]),
     ]
     for part, indices in expected_part_indices:
         self.assertEqual(getattr(profile, part + '_keypoint_index'), indices)
 def test_legacy_mpii3dhp17_keypoint_profile_3d_is_correct(self):
     """Checks the attributes derived for the LEGACY_3DMPII3DHP17 profile."""
     profile = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_3DMPII3DHP17')
     side = keypoint_profiles.LeftRightType

     # Basic metadata.
     self.assertEqual(profile.name, 'LEGACY_3DMPII3DHP17')
     self.assertEqual(profile.keypoint_dim, 3)
     self.assertEqual(profile.keypoint_num, 17)
     self.assertEqual(profile.keypoint_names, [
         'pelvis', 'head', 'neck', 'head_top', 'left_shoulder',
         'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist',
         'right_wrist', 'spine', 'left_hip', 'right_hip', 'left_knee',
         'right_knee', 'left_ankle', 'right_ankle'
     ])

     # Left/right classification of one keypoint and one segment.
     self.assertEqual(profile.keypoint_left_right_type(1), side.CENTRAL)
     self.assertEqual(profile.segment_left_right_type(1, 4), side.LEFT)

     # Offset and scale keypoint indices.
     self.assertEqual(profile.offset_keypoint_index, [0])
     self.assertEqual(profile.scale_keypoint_index_pairs,
                      [([0], [10]), ([10], [2])])

     # Name lookup; an unknown name maps to -1.
     self.assertEqual(profile.keypoint_index('left_shoulder'), 4)
     self.assertEqual(profile.keypoint_index('dummy'), -1)

     # Skeleton connectivity.
     self.assertEqual(profile.segment_index_pairs, [
         ([0], [10]), ([0], [11]), ([0], [12]), ([10], [2]), ([11], [13]),
         ([12], [14]), ([13], [15]), ([14], [16]), ([2], [1]), ([2], [4]),
         ([2], [5]), ([1], [3]), ([4], [6]), ([5], [7]), ([6], [8]),
         ([7], [9])
     ])
     self.assertAllEqual(profile.keypoint_affinity_matrix, [
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
     ])

     # Body-part index properties, checked as `<part>_keypoint_index`.
     expected_part_indices = [
         ('head', [1]),
         ('neck', [2]),
         ('left_shoulder', [4]),
         ('right_shoulder', [5]),
         ('left_elbow', [6]),
         ('right_elbow', [7]),
         ('left_wrist', [8]),
         ('right_wrist', [9]),
         ('spine', [10]),
         ('pelvis', [0]),
         ('left_hip', [11]),
         ('right_hip', [12]),
         ('left_knee', [13]),
         ('right_knee', [14]),
         ('left_ankle', [15]),
         ('right_ankle', [16]),
     ]
     for part, indices in expected_part_indices:
         self.assertEqual(getattr(profile, part + '_keypoint_index'), indices)
 def test_legacy_h36m17_keypoint_profile_3d_is_correct(self):
     """Checks the attributes derived for the LEGACY_3DH36M17 3D profile."""
     profile = keypoint_profiles.create_keypoint_profile_or_die(
         'LEGACY_3DH36M17')
     side = keypoint_profiles.LeftRightType

     # Basic metadata.
     self.assertEqual(profile.name, 'LEGACY_3DH36M17')
     self.assertEqual(profile.keypoint_dim, 3)
     self.assertEqual(profile.keypoint_num, 17)
     self.assertEqual(profile.keypoint_names, [
         'Hip', 'Head', 'Neck/Nose', 'Thorax', 'LShoulder', 'RShoulder',
         'LElbow', 'RElbow', 'LWrist', 'RWrist', 'Spine', 'LHip', 'RHip',
         'LKnee', 'RKnee', 'LFoot', 'RFoot'
     ])

     # Left/right classification of one keypoint and one segment.
     self.assertEqual(profile.keypoint_left_right_type(1), side.CENTRAL)
     self.assertEqual(profile.segment_left_right_type(1, 4), side.LEFT)

     # Offset and scale keypoint indices.
     self.assertEqual(profile.offset_keypoint_index, [0])
     self.assertEqual(profile.scale_keypoint_index_pairs,
                      [([0], [10]), ([10], [3])])

     # Name lookup; an unknown name maps to -1.
     self.assertEqual(profile.keypoint_index('Thorax'), 3)
     self.assertEqual(profile.keypoint_index('dummy'), -1)

     # Skeleton connectivity.
     self.assertEqual(profile.segment_index_pairs, [
         ([0], [10]), ([0], [11]), ([0], [12]), ([10], [3]), ([11], [13]),
         ([12], [14]), ([13], [15]), ([14], [16]), ([3], [2]), ([3], [4]),
         ([3], [5]), ([2], [1]), ([4], [6]), ([5], [7]), ([6], [8]),
         ([7], [9])
     ])
     self.assertAllEqual(profile.keypoint_affinity_matrix, [
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
     ])

     # One table, in standard-part order, drives three checks: the
     # `<part>_keypoint_index` properties, `standard_part_names`, and
     # `get_standard_part_index`.
     expected_part_indices = [
         ('HEAD', [1]),
         ('NECK', [3]),
         ('LEFT_SHOULDER', [4]),
         ('RIGHT_SHOULDER', [5]),
         ('LEFT_ELBOW', [6]),
         ('RIGHT_ELBOW', [7]),
         ('LEFT_WRIST', [8]),
         ('RIGHT_WRIST', [9]),
         ('SPINE', [10]),
         ('PELVIS', [0]),
         ('LEFT_HIP', [11]),
         ('RIGHT_HIP', [12]),
         ('LEFT_KNEE', [13]),
         ('RIGHT_KNEE', [14]),
         ('LEFT_ANKLE', [15]),
         ('RIGHT_ANKLE', [16]),
     ]
     for part, indices in expected_part_indices:
         self.assertEqual(
             getattr(profile, part.lower() + '_keypoint_index'), indices)
     self.assertEqual(profile.standard_part_names,
                      [part for part, _ in expected_part_indices])
     for part, indices in expected_part_indices:
         self.assertEqual(profile.get_standard_part_index(part), indices)
def test_transfer_keypoint_masks_case_2(self):
  """Checks mask transfer from the '3DSTD16' profile to '2DSTD13'.

  The second sample's masks are the complement of the first sample's, so each
  keypoint is checked both with mask 1.0 and with mask 0.0.
  """
  # Masks of the first input sample, one per '3DSTD16' keypoint.
  first_sample_masks = [
      1.0,  # NOSE
      1.0,  # NECK
      1.0,  # LEFT_SHOULDER
      1.0,  # RIGHT_SHOULDER
      0.0,  # LEFT_ELBOW
      1.0,  # RIGHT_ELBOW
      1.0,  # LEFT_WRIST
      0.0,  # RIGHT_WRIST
      0.0,  # SPINE
      0.0,  # PELVIS
      1.0,  # LEFT_HIP
      0.0,  # RIGHT_HIP
      1.0,  # LEFT_KNEE
      1.0,  # RIGHT_KNEE
      0.0,  # LEFT_ANKLE
      0.0,  # RIGHT_ANKLE
  ]
  # Masks of the second input sample (complement of the first).
  second_sample_masks = [
      0.0,  # NOSE
      0.0,  # NECK
      0.0,  # LEFT_SHOULDER
      0.0,  # RIGHT_SHOULDER
      1.0,  # LEFT_ELBOW
      0.0,  # RIGHT_ELBOW
      0.0,  # LEFT_WRIST
      1.0,  # RIGHT_WRIST
      1.0,  # SPINE
      1.0,  # PELVIS
      0.0,  # LEFT_HIP
      1.0,  # RIGHT_HIP
      0.0,  # LEFT_KNEE
      0.0,  # RIGHT_KNEE
      1.0,  # LEFT_ANKLE
      1.0,  # RIGHT_ANKLE
  ]
  # Shape = [2, 16].
  masks_3d = tf.constant([first_sample_masks, second_sample_masks])

  source_profile = keypoint_profiles.create_keypoint_profile_or_die('3DSTD16')
  target_profile = keypoint_profiles.create_keypoint_profile_or_die('2DSTD13')

  # Shape = [2, 13].
  masks_2d = keypoint_utils.transfer_keypoint_masks(masks_3d, source_profile,
                                                    target_profile)

  # The expected rows list the 13 '2DSTD13' keypoints (no NECK, SPINE or
  # PELVIS entry); each value equals the input mask of the same-named
  # keypoint, with NOSE_TIP taking the NOSE mask.
  expected_first = [
      1.0,  # NOSE_TIP
      1.0,  # LEFT_SHOULDER
      1.0,  # RIGHT_SHOULDER
      0.0,  # LEFT_ELBOW
      1.0,  # RIGHT_ELBOW
      1.0,  # LEFT_WRIST
      0.0,  # RIGHT_WRIST
      1.0,  # LEFT_HIP
      0.0,  # RIGHT_HIP
      1.0,  # LEFT_KNEE
      1.0,  # RIGHT_KNEE
      0.0,  # LEFT_ANKLE
      0.0,  # RIGHT_ANKLE
  ]
  expected_second = [
      0.0,  # NOSE_TIP
      0.0,  # LEFT_SHOULDER
      0.0,  # RIGHT_SHOULDER
      1.0,  # LEFT_ELBOW
      0.0,  # RIGHT_ELBOW
      0.0,  # LEFT_WRIST
      1.0,  # RIGHT_WRIST
      0.0,  # LEFT_HIP
      1.0,  # RIGHT_HIP
      0.0,  # LEFT_KNEE
      0.0,  # RIGHT_KNEE
      1.0,  # LEFT_ANKLE
      1.0,  # RIGHT_ANKLE
  ]
  self.assertAllClose(masks_2d, [expected_first, expected_second])
# Example #12
# 0
    def test_preprocess_keypoints_2d_with_projection(self):
        """Checks 3D-to-2D keypoint projection for a batch of poses.

        Each rotation range (azimuth, elevation, roll) has equal lower and
        upper bounds.
        """
        # Every 17-keypoint pose ends with the same 13 keypoints; poses differ
        # only in their first 4 keypoints.
        repeated_block = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0],
                          [7.0, 8.0, 9.0]]
        shared_tail = repeated_block * 3 + [[7.0, 8.0, 9.0]]

        pose_a = repeated_block + shared_tail
        pose_b = [[11.0, 12.0, 13.0], [13.0, 14.0, 15.0], [15.0, 16.0, 17.0],
                  [17.0, 18.0, 19.0]] + shared_tail
        pose_c = [[31.0, 32.0, 33.0], [33.0, 34.0, 35.0], [35.0, 36.0, 37.0],
                  [37.0, 38.0, 39.0]] + shared_tail
        # NOTE(review): 35.0 breaks the +2 pattern of its neighbours; the
        # expected values presumably depend on it, so it is kept as is.
        pose_d = [[41.0, 42.0, 43.0], [43.0, 44.0, 35.0], [45.0, 46.0, 47.0],
                  [47.0, 48.0, 49.0]] + shared_tail

        # Shape = [4, 2, 17, 3].
        keypoints_3d = tf.constant([
            [pose_a, pose_b],
            [pose_c, pose_d],
            [pose_a, pose_b],
            [pose_c, pose_a],
        ])

        profile_3d = (
            keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17')
        )
        profile_2d = (
            keypoint_profiles.create_keypoint_profile_or_die('LEGACY_2DCOCO13')
        )
        projected_2d, _ = input_generator.preprocess_keypoints_2d(
            keypoints_2d=None,
            keypoint_masks_2d=None,
            keypoints_3d=keypoints_3d,
            model_input_keypoint_type=common.
            MODEL_INPUT_KEYPOINT_TYPE_3D_PROJECTION,
            keypoint_profile_2d=profile_2d,
            keypoint_profile_3d=profile_3d,
            azimuth_range=(math.pi / 2.0, math.pi / 2.0),
            elevation_range=(-math.pi / 2.0, -math.pi / 2.0),
            roll_range=(math.pi, math.pi))

        # Expected 13-keypoint 2D projections of pose_a .. pose_d.
        projected_a = [[-0.08777856, 0.08777854], [0.0, 0.0],
                       [-0.08777856, 0.08777854], [-0.1613905, 0.16139045],
                       [-0.22400925, 0.22400923], [0.0, 0.0],
                       [-0.08777856, 0.08777854], [-0.22400925, 0.22400923],
                       [0.0, 0.0], [-0.08777856, 0.08777854],
                       [-0.1613905, 0.16139045], [-0.22400925, 0.22400923],
                       [-0.22400925, 0.22400923]]
        projected_b = [[-0.03107818, 0.03107818], [0.191008, -0.19100799],
                       [0.14718375, -0.14718372], [0.10647015, -0.10647012],
                       [0.06854735, -0.06854733], [0.191008, -0.19100799],
                       [0.14718375, -0.14718372], [0.06854735, -0.06854733],
                       [0.191008, -0.19100799], [0.14718375, -0.14718372],
                       [0.10647015, -0.10647012], [0.06854735, -0.06854733],
                       [0.06854735, -0.06854733]]
        projected_c = [[-0.0098562, 0.0098562], [0.1755229, -0.17552288],
                       [0.16192658, -0.16192657], [0.14864118, -0.14864117],
                       [0.13565618, -0.13565615], [0.1755229, -0.17552288],
                       [0.16192658, -0.16192657], [0.13565618, -0.13565615],
                       [0.1755229, -0.17552288], [0.16192658, -0.16192657],
                       [0.14864118, -0.14864117], [0.13565618, -0.13565615],
                       [0.13565618, -0.13565615]]
        projected_d = [[-0.00762777, 0.00762777], [0.17376202, -0.173762],
                       [0.16365208, -0.16365206], [0.15371482, -0.1537148],
                       [0.14394586, -0.14394584], [0.17376202, -0.173762],
                       [0.16365208, -0.16365206], [0.14394586, -0.14394584],
                       [0.17376202, -0.173762], [0.16365208, -0.16365206],
                       [0.15371482, -0.1537148], [0.14394586, -0.14394584],
                       [0.14394586, -0.14394584]]

        expected_keypoints_2d = [
            [projected_a, projected_b],
            [projected_c, projected_d],
            [projected_a, projected_b],
            [projected_c, projected_a],
        ]
        self.assertAllClose(projected_2d, expected_keypoints_2d)
# Example #13
# 0
    def test_preprocess_keypoints_2d_with_input_and_projection(self):
        """Checks per-instance mixing of input 2D and projected 3D keypoints.

        Where `assignment` is True the expected output keeps the input 2D
        keypoints; where it is False it holds the projected 3D keypoints.
        """
        # 13-keypoint 2D poses sharing the same last 9 keypoints.
        planar_block = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        planar_tail = planar_block * 2 + [[1.0, 2.0]]

        planar_a = planar_block + planar_tail
        planar_b = ([[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0]] +
                    planar_tail)
        planar_c = ([[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0]] +
                    planar_tail)
        planar_d = ([[41.0, 42.0], [43.0, 44.0], [45.0, 46.0], [47.0, 48.0]] +
                    planar_tail)

        # Shape = [4, 2, 13, 2].
        input_keypoints_2d = tf.constant([
            [planar_a, planar_b],
            [planar_c, planar_d],
            [planar_a, planar_b],
            [planar_c, planar_d],
        ])

        # Shape = [4, 2, 13].
        input_masks_2d = tf.constant([
            [[0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10,
              0.11, 0.12, 0.13],
             [0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23,
              0.24, 0.25, 0.26]],
            [[0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36,
              0.37, 0.38, 0.39],
             [0.40, 0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49,
              0.50, 0.51, 0.52]],
            [[0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.60, 0.61, 0.62,
              0.63, 0.64, 0.65],
             [0.66, 0.67, 0.68, 0.69, 0.70, 0.71, 0.72, 0.73, 0.74, 0.75,
              0.76, 0.77, 0.78]],
            [[0.79, 0.80, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88,
              0.89, 0.90, 0.91],
             [0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.00, 0.99,
              0.98, 0.97, 0.96]],
        ])

        # 17-keypoint 3D poses sharing the same last 13 keypoints.
        repeated_block = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0],
                          [7.0, 8.0, 9.0]]
        shared_tail = repeated_block * 3 + [[7.0, 8.0, 9.0]]

        pose_a = repeated_block + shared_tail
        pose_b = [[11.0, 12.0, 13.0], [13.0, 14.0, 15.0], [15.0, 16.0, 17.0],
                  [17.0, 18.0, 19.0]] + shared_tail
        pose_c = [[31.0, 32.0, 33.0], [33.0, 34.0, 35.0], [35.0, 36.0, 37.0],
                  [37.0, 38.0, 39.0]] + shared_tail
        # NOTE(review): 35.0 breaks the +2 pattern of its neighbours; the
        # expected values presumably depend on it, so it is kept as is.
        pose_d = [[41.0, 42.0, 43.0], [43.0, 44.0, 35.0], [45.0, 46.0, 47.0],
                  [47.0, 48.0, 49.0]] + shared_tail

        # Shape = [4, 2, 17, 3].
        input_keypoints_3d = tf.constant([
            [pose_a, pose_b],
            [pose_c, pose_d],
            [pose_a, pose_b],
            [pose_c, pose_a],
        ])

        # Shape = [4, 2].
        assignment = tf.constant([
            [True, True],
            [False, False],
            [True, False],
            [False, True],
        ])

        profile_3d = (
            keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17')
        )
        profile_2d = (
            keypoint_profiles.create_keypoint_profile_or_die('LEGACY_2DCOCO13')
        )
        mixed_keypoints_2d, _ = input_generator.preprocess_keypoints_2d(
            input_keypoints_2d,
            input_masks_2d,
            input_keypoints_3d,
            model_input_keypoint_type=common.
            MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT_AND_3D_PROJECTION,
            keypoint_profile_2d=profile_2d,
            keypoint_profile_3d=profile_3d,
            azimuth_range=(math.pi / 2.0, math.pi / 2.0),
            elevation_range=(-math.pi / 2.0, -math.pi / 2.0),
            roll_range=(math.pi, math.pi),
            projection_mix_batch_assignment=assignment)

        # Expected 13-keypoint 2D projections of pose_b .. pose_d.
        projected_b = [[-0.03107818, 0.03107818], [0.191008, -0.19100799],
                       [0.14718375, -0.14718372], [0.10647015, -0.10647012],
                       [0.06854735, -0.06854733], [0.191008, -0.19100799],
                       [0.14718375, -0.14718372], [0.06854735, -0.06854733],
                       [0.191008, -0.19100799], [0.14718375, -0.14718372],
                       [0.10647015, -0.10647012], [0.06854735, -0.06854733],
                       [0.06854735, -0.06854733]]
        projected_c = [[-0.0098562, 0.0098562], [0.1755229, -0.17552288],
                       [0.16192658, -0.16192657], [0.14864118, -0.14864117],
                       [0.13565618, -0.13565615], [0.1755229, -0.17552288],
                       [0.16192658, -0.16192657], [0.13565618, -0.13565615],
                       [0.1755229, -0.17552288], [0.16192658, -0.16192657],
                       [0.14864118, -0.14864117], [0.13565618, -0.13565615],
                       [0.13565618, -0.13565615]]
        projected_d = [[-0.00762777, 0.00762777], [0.17376202, -0.173762],
                       [0.16365208, -0.16365206], [0.15371482, -0.1537148],
                       [0.14394586, -0.14394584], [0.17376202, -0.173762],
                       [0.16365208, -0.16365206], [0.14394586, -0.14394584],
                       [0.17376202, -0.173762], [0.16365208, -0.16365206],
                       [0.15371482, -0.1537148], [0.14394586, -0.14394584],
                       [0.14394586, -0.14394584]]

        expected_keypoints_2d = [
            [planar_a, planar_b],
            [projected_c, projected_d],
            [planar_a, projected_b],
            [projected_c, planar_d],
        ]
        self.assertAllClose(mixed_keypoints_2d, expected_keypoints_2d)
# Example #14
# 0
    def test_preprocess_keypoints_2d_with_projection(self):
        """Checks that 3D-to-2D keypoint projection runs on a batch of poses.

        Each rotation range (azimuth, elevation, roll) has equal lower and
        upper bounds.
        """
        # Every 17-keypoint pose ends with the same 13 keypoints; poses differ
        # only in their first 4 keypoints.
        repeated_block = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0],
                          [7.0, 8.0, 9.0]]
        shared_tail = repeated_block * 3 + [[7.0, 8.0, 9.0]]

        pose_a = repeated_block + shared_tail
        pose_b = [[11.0, 12.0, 13.0], [13.0, 14.0, 15.0], [15.0, 16.0, 17.0],
                  [17.0, 18.0, 19.0]] + shared_tail
        pose_c = [[31.0, 32.0, 33.0], [33.0, 34.0, 35.0], [35.0, 36.0, 37.0],
                  [37.0, 38.0, 39.0]] + shared_tail
        # NOTE(review): 35.0 breaks the +2 pattern of its neighbours; the
        # expected values presumably depend on it, so it is kept as is.
        pose_d = [[41.0, 42.0, 43.0], [43.0, 44.0, 35.0], [45.0, 46.0, 47.0],
                  [47.0, 48.0, 49.0]] + shared_tail

        # Shape = [4, 2, 17, 3].
        keypoints_3d = tf.constant([
            [pose_a, pose_b],
            [pose_c, pose_d],
            [pose_a, pose_b],
            [pose_c, pose_a],
        ])

        profile_3d = (
            keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17')
        )
        profile_2d = (
            keypoint_profiles.create_keypoint_profile_or_die('LEGACY_2DCOCO13')
        )
        projected_2d, _ = input_generator.preprocess_keypoints_2d(
            keypoints_2d=None,
            keypoint_masks_2d=None,
            keypoints_3d=keypoints_3d,
            model_input_keypoint_type=common.
            MODEL_INPUT_KEYPOINT_TYPE_3D_PROJECTION,
            keypoint_profile_2d=profile_2d,
            keypoint_profile_3d=profile_3d,
            azimuth_range=(math.pi / 2.0, math.pi / 2.0),
            elevation_range=(-math.pi / 2.0, -math.pi / 2.0),
            roll_range=(math.pi, math.pi))

        # The expected values below were copied from test output rather than
        # derived by hand, so this test mainly guards that the code path stays
        # executable. Projection accuracy is tested separately.
        # NOTE(review): an earlier version of this note mentioned batch mixing,
        # which this test does not exercise.
        projected_a = [[-0.08777856, -0.08777856], [0.0, 0.0],
                       [-0.08777856, -0.08777856], [-0.1613905, -0.1613905],
                       [-0.22400928, -0.22400929], [0.0, 0.0],
                       [-0.08777856, -0.08777856], [-0.22400928, -0.22400929],
                       [0.0, 0.0], [-0.08777856, -0.08777856],
                       [-0.1613905, -0.1613905], [-0.22400928, -0.22400929],
                       [-0.22400928, -0.22400929]]
        projected_b = [[-0.03107818, -0.03107818], [0.19100799, 0.19100802],
                       [0.14718375, 0.14718376], [0.10647015, 0.10647015],
                       [0.06854735, 0.06854735], [0.19100799, 0.19100802],
                       [0.14718375, 0.14718376], [0.06854735, 0.06854735],
                       [0.19100799, 0.19100802], [0.14718375, 0.14718376],
                       [0.10647015, 0.10647015], [0.06854735, 0.06854735],
                       [0.06854735, 0.06854735]]
        projected_c = [[-0.0098562, -0.0098562], [0.17552288, 0.1755229],
                       [0.16192658, 0.1619266], [0.14864118, 0.1486412],
                       [0.13565616, 0.13565616], [0.17552288, 0.1755229],
                       [0.16192658, 0.1619266], [0.13565616, 0.13565616],
                       [0.17552288, 0.1755229], [0.16192658, 0.1619266],
                       [0.14864118, 0.1486412], [0.13565616, 0.13565616],
                       [0.13565616, 0.13565616]]
        projected_d = [[-0.00734754, 0.02939016], [0.17376201, 0.17376202],
                       [0.16365208, 0.16365209], [0.15371482, 0.15371484],
                       [0.14394586, 0.14394587], [0.17376201, 0.17376202],
                       [0.16365208, 0.16365209], [0.14394586, 0.14394587],
                       [0.17376201, 0.17376202], [0.16365208, 0.16365209],
                       [0.15371482, 0.15371484], [0.14394586, 0.14394587],
                       [0.14394586, 0.14394587]]

        expected_keypoints_2d = [
            [projected_a, projected_b],
            [projected_c, projected_d],
            [projected_a, projected_b],
            [projected_c, projected_a],
        ]
        self.assertAllClose(projected_2d, expected_keypoints_2d)
def _lookup_3d_keypoints(keypoint_dict, row, offset):
  """Returns the flattened 3D keypoints matching one CSV entry.

  Args:
    keypoint_dict: A dictionary for loaded 3D keypoints keyed by (subject,
      action).
    row: A list of string fields for one CSV row.
    offset: An integer index of the entry's first field in `row`. The subject
      is read at `offset`, the action at `offset + 1` and the frame index at
      `offset + 3`.

  Returns:
    The 3D keypoints of the matched frame selected by `KEYPOINT_3D_INDICES_H5`,
    reshaped to one dimension.

  Raises:
    KeyError: If (subject, action) is not in `keypoint_dict`.
  """
  subject = row[offset]
  # Replace names to be consistent with H5 file names.
  action = row[offset + 1].replace('TakingPhoto', 'Photo').replace(
      'WalkingDog', 'WalkDog')
  frame_index = int(row[offset + 3])
  keypoints_3d = keypoint_dict[(subject, action)]
  return keypoints_3d[frame_index, KEYPOINT_3D_INDICES_H5, :].reshape([-1])


def load_2d_keypoints_and_write_tfrecord_with_3d_keypoints(
    input_csv_file, keypoint_dict, output_tfrecord_file, read_csv_pairs,
    num_shards):
  """Loads 2D keypoints from a CSV file and writes a TFRecord with 3D poses.

  The TFRecord written contains the 2D keypoints with the corresponding 3D
  keypoints stored in keypoint_dict. The first CSV row is treated as the header.
  The 3D keypoint column names are appended to that header. Each later row is
  matched to 3D keypoints by its subject (column 0), action (column 1) and
  frame index (column 3).

  Args:
    input_csv_file: A string of the CSV file name containing 2D keypoints to
      load with subjects, actions and timestamps to be matched to 3D keypoints
      from keypoint_dict.
    keypoint_dict: A dictionary for loaded 3D keypoints. Keys are (subject,
      action) and values are of shape [sequence_length, num_keypoints, 3].
    output_tfrecord_file: A string of output filename for the TFRecord
      containing 2D and 3D keypoints.
    read_csv_pairs: A boolean that is True when each row of the CSV file stores
      paired entries (the first half is the anchor and the second half is its
      positive match). It is False when the row contains a single entry.
    num_shards: An integer for the number of shards in the output TFRecord file.
      When greater than 1, shard files are named
      `<output_tfrecord_file>-XXXXX-of-YYYYY`. Rows are assigned to shards by
      line index modulo `num_shards`.

  Raises:
    ValueError: If `num_shards` is not positive, or if `read_csv_pairs` is True
      and a row has an odd number of fields.
  """
  # Validate before any output file is created; 0 would otherwise fail later
  # with a ZeroDivisionError in the shard selection below.
  if num_shards < 1:
    raise ValueError(
        'Number of shards must be positive: {}.'.format(num_shards))

  keypoint_profile_h36m17 = (
      keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17'))

  if num_shards > 1:
    output_tfrecord_files = [
        output_tfrecord_file + '-{:05d}-of-{:05d}'.format(i, num_shards)
        for i in range(num_shards)
    ]
  else:
    output_tfrecord_files = [output_tfrecord_file]

  tfrecord_writers = []
  try:
    for filename in output_tfrecord_files:
      tfrecord_writers.append(tf.python_io.TFRecordWriter(filename))

    # Stays None until the first (header) row has been read.
    headers = None
    with tf.io.gfile.GFile(input_csv_file, 'r') as csv_rows:
      for shard_counter, line in enumerate(csv_rows):
        # Drop the line terminator so it does not leak into the last field
        # (and, in single-entry mode, into the last header name).
        row = line.rstrip('\r\n').split(',')

        feature_size = len(row)
        if read_csv_pairs:
          if len(row) % 2 != 0:
            raise ValueError('CSV row has length {} but it should have an even '
                             'number of elements.'.format(len(row)))
          feature_size = len(row) // 2

        if headers is None:
          # Keep the first half of the row as header if the csv file contains
          # pairs. Otherwise, keep the full row.
          headers = row[:feature_size]
          # Add 3D pose headers using the keypoint names to the header list.
          prefix = 'image/object/part_3d/'
          suffix = '/center/'
          for name in keypoint_profile_h36m17.keypoint_names:
            for axis in ('x', 'y', 'z'):
              headers.append(prefix + name + suffix + axis)
          continue

        # Obtain matching 3D keypoints from keypoint_dict.
        anchor_keypoint_3d = _lookup_3d_keypoints(keypoint_dict, row, 0)

        if read_csv_pairs:
          # The second element in the pair is the positive match.
          positive_keypoint_3d = _lookup_3d_keypoints(keypoint_dict, row,
                                                      feature_size)
          # Concatenate 3D keypoints into current row with 2D keypoints.
          row_with_3d_keypoints = np.concatenate(
              (row[:feature_size], anchor_keypoint_3d, row[feature_size:],
               positive_keypoint_3d))
        else:
          row_with_3d_keypoints = np.concatenate((row, anchor_keypoint_3d))

        serialized_example = create_serialized_example_with_2d_3d_keypoints(
            row_with_3d_keypoints, headers, write_pairs=read_csv_pairs)
        tfrecord_writers[shard_counter % num_shards].write(serialized_example)
  finally:
    # Flush and release every shard, including on error, so records are not
    # lost in writer buffers.
    for writer in tfrecord_writers:
      writer.close()
Beispiel #16
0
def _create_all_pairs_sample_distance_fn(common_module, distance_kernel,
                                         pairwise_reduction,
                                         componentwise_reduction):
    """Creates an all-pairs sample distance function with shared settings.

    Args:
      common_module: A module providing `DISTANCE_PAIR_TYPE_ALL_PAIRS`.
      distance_kernel: A string for the distance kernel.
      pairwise_reduction: A string for the pairwise distance reduction.
      componentwise_reduction: A string for the componentwise reduction.

    Returns:
      Whatever `loss_utils.create_sample_distance_fn` returns for these
      settings.
    """
    return loss_utils.create_sample_distance_fn(
        pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
        distance_kernel=distance_kernel,
        pairwise_reduction=pairwise_reduction,
        componentwise_reduction=componentwise_reduction,
        # We initialize the sigmoid parameters to avoid model being stuck
        # in a `dead zone` at the beginning of training.
        L2_SIGMOID_MATCHING_PROB_a_initializer=(
            tf.initializers.constant(-0.65)),
        L2_SIGMOID_MATCHING_PROB_b_initializer=(
            tf.initializers.constant(-0.5)),
        EXPECTED_LIKELIHOOD_min_stddev=0.1,
        EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0)


def _kernel_and_reduction_mismatch(common_module, distance_kernel,
                                   pairwise_reduction):
    """Returns whether a kernel and a pairwise reducer are used unpaired.

    The probability-style kernels (`L2_SIGMOID_MATCHING_PROB`,
    `EXPECTED_LIKELIHOOD`) must be used together with one of the
    probability-style reducers (`NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN`,
    `ONE_MINUS_MEAN`), and vice versa.
    """
    uses_probability_kernel = distance_kernel in (
        common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
        common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD,
    )
    uses_probability_reduction = pairwise_reduction in (
        common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
        common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
        common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN,
    )
    return uses_probability_kernel != uses_probability_reduction


def _angles_in_degrees_to_radians(angles):
    """Converts an iterable of angles in degrees to a list of radians."""
    return [float(x) / 180.0 * math.pi for x in angles]


def _validate_and_setup(common_module, keypoint_distance_config_override):
    """Validates and sets up training configurations.

    Unspecified positive-pairwise / triplet-mining flags are first defaulted
    (in place on `FLAGS`) to their triplet-loss counterparts, then all flags
    are validated and a configuration dictionary is built.

    Args:
      common_module: A module providing `validate`, the `SUPPORTED_*`
        collections and the distance/embedding type constants.
      keypoint_distance_config_override: A dictionary that may contain
        'keypoint_distance_fn' and/or 'min_negative_keypoint_distance' entries,
        which take precedence over the values derived from flags.

    Returns:
      A dictionary of training configurations.

    Raises:
      ValueError: If the flag combination is unsupported, or if no keypoint
        distance function / minimum negative keypoint distance is configured.
    """
    # Set default values for unspecified flags.
    if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
        FLAGS.use_normalized_embeddings_for_triplet_mining = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
        FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
            FLAGS.use_normalized_embeddings_for_triplet_loss)
    if FLAGS.positive_pairwise_distance_type is None:
        FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
    if FLAGS.positive_pairwise_distance_kernel is None:
        FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
    if FLAGS.positive_pairwise_pairwise_reduction is None:
        FLAGS.positive_pairwise_pairwise_reduction = (
            FLAGS.triplet_pairwise_reduction)
    if FLAGS.positive_pairwise_componentwise_reduction is None:
        FLAGS.positive_pairwise_componentwise_reduction = (
            FLAGS.triplet_componentwise_reduction)

    # Validate flags.
    validate_flag = common_module.validate
    validate_flag(FLAGS.model_input_keypoint_type,
                  common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
    validate_flag(FLAGS.embedding_type,
                  common_module.SUPPORTED_EMBEDDING_TYPES)
    validate_flag(FLAGS.keypoint_distance_type,
                  common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.triplet_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.triplet_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.triplet_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_distance_type,
                  common_module.SUPPORTED_DISTANCE_TYPES)
    validate_flag(FLAGS.positive_pairwise_distance_kernel,
                  common_module.SUPPORTED_DISTANCE_KERNELS)
    validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                  common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
    validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                  common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

    sample_distance_types = (
        common_module.DISTANCE_TYPE_SAMPLE,
        common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE,
    )

    if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
        if FLAGS.triplet_distance_type in sample_distance_types:
            raise ValueError(
                'No support for triplet distance type `%s` for embedding type `%s`.'
                % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
        if FLAGS.kl_regularization_loss_weight > 0.0:
            raise ValueError(
                'No support for KL regularization loss for embedding type `%s`.'
                % FLAGS.embedding_type)

    # Sample-based distances need embedding samples. Report the distance type
    # that actually requires them (triplet or positive pairwise).
    for distance_type in (FLAGS.triplet_distance_type,
                          FLAGS.positive_pairwise_distance_type):
        if (distance_type in sample_distance_types
                and FLAGS.num_embedding_samples <= 0):
            raise ValueError(
                'Must specify positive `num_embedding_samples` to use `%s` '
                'triplet/positive pairwise distance type.' % distance_type)

    if (_kernel_and_reduction_mismatch(common_module,
                                       FLAGS.triplet_distance_kernel,
                                       FLAGS.triplet_pairwise_reduction)
            or _kernel_and_reduction_mismatch(
                common_module, FLAGS.positive_pairwise_distance_kernel,
                FLAGS.positive_pairwise_pairwise_reduction)):
        raise ValueError(
            'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
            'kernel and `NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN` or '
            '`ONE_MINUS_MEAN` pairwise reducer in pairs.')

    # Set up configurations.
    configs = {
        'keypoint_profile_3d':
        keypoint_profiles.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'keypoint_profile_2d':
        keypoint_profiles.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_2d),
        'target_keypoint_profile_3d':
        keypoint_profiles.create_keypoint_profile_or_die(
            FLAGS.input_keypoint_profile_name_3d),
        'triplet_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          common_module=common_module),
        'triplet_mining_embedding_keys':
        pipeline_utils.get_embedding_keys(FLAGS.triplet_distance_type,
                                          replace_samples_with_means=True,
                                          common_module=common_module),
        'triplet_embedding_sample_distance_fn':
        _create_all_pairs_sample_distance_fn(
            common_module,
            distance_kernel=FLAGS.triplet_distance_kernel,
            pairwise_reduction=FLAGS.triplet_pairwise_reduction,
            componentwise_reduction=FLAGS.triplet_componentwise_reduction),
        'positive_pairwise_embedding_keys':
        pipeline_utils.get_embedding_keys(
            FLAGS.positive_pairwise_distance_type,
            common_module=common_module),
        'positive_pairwise_embedding_sample_distance_fn':
        _create_all_pairs_sample_distance_fn(
            common_module,
            distance_kernel=FLAGS.positive_pairwise_distance_kernel,
            pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
            componentwise_reduction=(
                FLAGS.positive_pairwise_componentwise_reduction)),
        'summarize_matching_sigmoid_vars':
        FLAGS.triplet_distance_kernel
        in [common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB],
        'random_projection_azimuth_range':
        _angles_in_degrees_to_radians(FLAGS.random_projection_azimuth_range),
        'random_projection_elevation_range':
        _angles_in_degrees_to_radians(
            FLAGS.random_projection_elevation_range),
        'random_projection_roll_range':
        _angles_in_degrees_to_radians(FLAGS.random_projection_roll_range),
    }

    if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
        configs.update({
            'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
            'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
        })
    # We use the following assignments to get around pytype check failures.
    # TODO(liuti): Figure out a better workaround.
    if 'keypoint_distance_fn' in keypoint_distance_config_override:
        configs['keypoint_distance_fn'] = (
            keypoint_distance_config_override['keypoint_distance_fn'])
    if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
        configs['min_negative_keypoint_distance'] = (
            keypoint_distance_config_override['min_negative_keypoint_distance']
        )
    if ('keypoint_distance_fn' not in configs
            or 'min_negative_keypoint_distance' not in configs):
        raise ValueError('Invalid keypoint distance config: %s.' %
                         str(configs))

    if FLAGS.task == 0 and not FLAGS.profile_only:
        # Save all key flags (only from task 0, and not in profile-only runs).
        pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                                 'all_flags.train.json')

    return configs