def test_read_batch_from_dataset_tables(self):
  """Reads one test table twice and checks the batched output keys/shapes."""
  testdata_dir = 'poem/testdata'  # Assume $PWD == "google_research/".
  table_path = os.path.join(FLAGS.test_srcdir, testdata_dir,
                            'tfe-2.tfrecords')
  names_3d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M17').keypoint_names
  names_2d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_2DCOCO13').keypoint_names

  inputs = pipeline_utils.read_batch_from_dataset_tables(
      [table_path, table_path],
      batch_sizes=[4, 2],
      num_instances_per_record=2,
      shuffle=True,
      num_epochs=None,
      keypoint_names_3d=names_3d,
      keypoint_names_2d=names_2d,
      seed=0)

  # The leading dimension 6 equals 4 + 2, the per-table batch sizes requested
  # above. Kept as an ordered sequence so the checks run in a fixed order.
  expected_shapes = (
      ('image_sizes', [6, 2, 2]),
      ('keypoints_2d', [6, 2, 13, 2]),
      ('keypoint_scores_2d', [6, 2, 13]),
      ('keypoint_masks_2d', [6, 2, 13]),
      ('keypoints_3d', [6, 2, 17, 3]),
  )
  self.assertCountEqual(inputs.keys(), [key for key, _ in expected_shapes])
  for key, shape in expected_shapes:
    self.assertEqual(inputs[key].shape, shape, msg=key)
def test_random_project_and_select_keypoints(self):
  """Projects a LEGACY_3DH36M17 pose and checks the normalized 2D result."""
  # 17 keypoints; the i-th one sits at (i, i, i).
  keypoints_3d = tf.constant([[float(i)] * 3 for i in range(17)])
  profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M17')
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_2DCOCO13')
  selected_names = profile_2d.compatible_keypoint_name_dict['LEGACY_3DH36M17']

  # Every range below has identical endpoints, pinning each camera angle to a
  # single value.
  half_pi = math.pi / 2.0
  projected, _ = keypoint_utils.random_project_and_select_keypoints(
      keypoints_3d,
      keypoint_profile_3d=profile_3d,
      output_keypoint_names=selected_names,
      azimuth_range=(half_pi, half_pi),
      elevation_range=(half_pi, half_pi),
      roll_range=(-half_pi, -half_pi))
  normalized, _, _ = profile_2d.normalize(projected)

  expected_keypoints_2d = [
      [-0.4356161, 0.4356161],
      [-0.32822642, 0.32822642],
      [-0.2897728, 0.28977284],
      [-0.24986516, 0.24986516],
      [-0.2084193, 0.2084193],
      [-0.16534455, 0.16534461],
      [-0.12054307, 0.1205431],
      [-0.025327, 0.025327],
      [0.025327, -0.025327],
      [0.07818867, -0.07818867],
      [0.13340548, -0.13340548],
      [0.19113854, -0.19113848],
      [0.2515637, -0.25156358],
  ]
  self.assertAllClose(normalized, expected_keypoints_2d)
def test_randomly_project_and_select_keypoints(self):
  """Projects a 3DSTD16 pose with pinned camera ranges onto 2DSTD13 names."""
  # One row per 3DSTD16 keypoint. Values are kept as literals (rather than
  # computed) so the floats match the fixture exactly.
  pose_3d = tf.constant([
      [2.0, 1.0, 3.0],  # HEAD.
      [2.01, 1.01, 3.01],  # NECK.
      [2.02, 1.02, 3.02],  # LEFT_SHOULDER.
      [2.03, 1.03, 3.03],  # RIGHT_SHOULDER.
      [2.04, 1.04, 3.04],  # LEFT_ELBOW.
      [2.05, 1.05, 3.05],  # RIGHT_ELBOW.
      [2.06, 1.06, 3.06],  # LEFT_WRIST.
      [2.07, 1.07, 3.07],  # RIGHT_WRIST.
      [2.08, 1.08, 3.08],  # SPINE.
      [2.09, 1.09, 3.09],  # PELVIS.
      [2.10, 1.10, 3.10],  # LEFT_HIP.
      [2.11, 1.11, 3.11],  # RIGHT_HIP.
      [2.12, 1.12, 3.12],  # LEFT_KNEE.
      [2.13, 1.13, 3.13],  # RIGHT_KNEE.
      [2.14, 1.14, 3.14],  # LEFT_ANKLE.
      [2.15, 1.15, 3.15],  # RIGHT_ANKLE.
  ])
  profile_3d = keypoint_profiles.create_keypoint_profile_or_die('3DSTD16')
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die('2DSTD13')
  selected_names = profile_2d.compatible_keypoint_name_dict['3DSTD16']

  # Each range has identical endpoints, pinning the camera parameter to a
  # single value.
  half_pi = math.pi / 2.0
  azimuth_range = (half_pi, half_pi)
  elevation_range = (-half_pi, -half_pi)
  roll_range = (math.pi, math.pi)
  depth_range = (2.0, 2.0)

  projected, _ = keypoint_utils.randomly_project_and_select_keypoints(
      pose_3d,
      keypoint_profile_3d=profile_3d,
      output_keypoint_names=selected_names,
      azimuth_range=azimuth_range,
      elevation_range=elevation_range,
      roll_range=roll_range,
      normalized_camera_depth_range=depth_range,
      normalize_before_projection=False)

  expected_keypoints_2d = [
      [-1.0 / 4.0, -3.0 / 4.0],  # NOSE_TIP
      [-1.02 / 4.02, -3.02 / 4.02],  # LEFT_SHOULDER.
      [-1.03 / 4.03, -3.03 / 4.03],  # RIGHT_SHOULDER.
      [-1.04 / 4.04, -3.04 / 4.04],  # LEFT_ELBOW.
      [-1.05 / 4.05, -3.05 / 4.05],  # RIGHT_ELBOW.
      [-1.06 / 4.06, -3.06 / 4.06],  # LEFT_WRIST.
      [-1.07 / 4.07, -3.07 / 4.07],  # RIGHT_WRIST.
      [-1.10 / 4.10, -3.10 / 4.10],  # LEFT_HIP.
      [-1.11 / 4.11, -3.11 / 4.11],  # RIGHT_HIP.
      [-1.12 / 4.12, -3.12 / 4.12],  # LEFT_KNEE.
      [-1.13 / 4.13, -3.13 / 4.13],  # RIGHT_KNEE.
      [-1.14 / 4.14, -3.14 / 4.14],  # LEFT_ANKLE.
      [-1.15 / 4.15, -3.15 / 4.15],  # RIGHT_ANKLE.
  ]
  self.assertAllClose(projected, expected_keypoints_2d)
def test_select_keypoints_by_name(self):
  """Selects the LEGACY_2DCOCO13-compatible subset of 17 input keypoints."""
  # 17 keypoints; the i-th one sits at (i, i, i), so each output row reveals
  # which input index it was taken from.
  input_keypoints = tf.constant([[float(i)] * 3 for i in range(17)])
  profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M17')
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_2DCOCO13')
  selected_names = profile_2d.compatible_keypoint_name_dict['LEGACY_3DH36M17']

  output_keypoints, _ = keypoint_utils.select_keypoints_by_name(
      input_keypoints,
      input_keypoint_names=profile_3d.keypoint_names,
      output_keypoint_names=selected_names)

  # Input indices expected to survive the selection, in output order.
  kept_indices = (1, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16)
  expected_keypoints = [[float(i)] * 3 for i in kept_indices]
  self.assertAllClose(output_keypoints, expected_keypoints)
def test_legacy_h36m13_keypoint_profile_3d_is_correct(self):
  """Pins the public properties of the 'LEGACY_3DH36M13' keypoint profile."""
  profile = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M13')
  left_right = keypoint_profiles.LeftRightType

  # Identity and keypoint layout.
  self.assertEqual(profile.name, 'LEGACY_3DH36M13')
  self.assertEqual(profile.keypoint_dim, 3)
  self.assertEqual(profile.keypoint_num, 13)
  expected_names = [
      'Head', 'LShoulder', 'RShoulder', 'LElbow', 'RElbow', 'LWrist',
      'RWrist', 'LHip', 'RHip', 'LKnee', 'RKnee', 'LFoot', 'RFoot'
  ]
  self.assertEqual(profile.keypoint_names, expected_names)

  # Left/right classification of a keypoint and of a segment.
  self.assertEqual(profile.keypoint_left_right_type(1), left_right.LEFT)
  self.assertEqual(profile.segment_left_right_type(1, 4), left_right.CENTRAL)

  # Normalization anchors and name lookup.
  self.assertEqual(profile.offset_keypoint_index, [7])
  self.assertEqual(profile.scale_keypoint_index_pairs, [([7, 8], [1, 2])])
  self.assertEqual(profile.keypoint_index('LShoulder'), 1)
  self.assertEqual(profile.keypoint_index('dummy'), -1)

  # Skeleton topology.
  expected_segments = [
      ([7, 8], [1, 2]), ([7, 8], [7]), ([7, 8], [8]), ([7], [9]),
      ([8], [10]), ([9], [11]), ([10], [12]), ([1, 2], [0]),
      ([1, 2], [1]), ([1, 2], [2]), ([1], [3]), ([2], [4]),
      ([3], [5]), ([4], [6]),
  ]
  self.assertEqual(profile.segment_index_pairs, expected_segments)
  expected_affinity = [
      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
      [1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0],
      [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
      [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
  ]
  self.assertAllEqual(profile.keypoint_affinity_matrix, expected_affinity)

  # Standard body-part properties, checked in a fixed order. Each entry maps
  # the `<part>_keypoint_index` property to its expected keypoint indices.
  expected_part_indices = (
      ('head', [0]),
      ('neck', [1, 2]),
      ('left_shoulder', [1]),
      ('right_shoulder', [2]),
      ('left_elbow', [3]),
      ('right_elbow', [4]),
      ('left_wrist', [5]),
      ('right_wrist', [6]),
      ('spine', [1, 2, 7, 8]),
      ('pelvis', [7, 8]),
      ('left_hip', [7]),
      ('right_hip', [8]),
      ('left_knee', [9]),
      ('right_knee', [10]),
      ('left_ankle', [11]),
      ('right_ankle', [12]),
  )
  for part, indices in expected_part_indices:
    attribute = part + '_keypoint_index'
    self.assertEqual(getattr(profile, attribute), indices, msg=attribute)
def test_std13_keypoint_profile_3d_is_correct(self):
  """Pins the public properties of the '3DSTD13' keypoint profile."""
  profile = keypoint_profiles.create_keypoint_profile_or_die('3DSTD13')
  left_right = keypoint_profiles.LeftRightType

  # Identity and keypoint layout.
  self.assertEqual(profile.name, '3DSTD13')
  self.assertEqual(profile.keypoint_dim, 3)
  self.assertEqual(profile.keypoint_num, 13)
  expected_names = [
      'HEAD', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW', 'RIGHT_ELBOW',
      'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_HIP', 'RIGHT_HIP', 'LEFT_KNEE',
      'RIGHT_KNEE', 'LEFT_ANKLE', 'RIGHT_ANKLE'
  ]
  self.assertEqual(profile.keypoint_names, expected_names)

  # Left/right classification of a keypoint and of a segment.
  self.assertEqual(profile.keypoint_left_right_type(1), left_right.LEFT)
  self.assertEqual(profile.segment_left_right_type(1, 2), left_right.CENTRAL)

  # Normalization anchors and name lookup.
  self.assertEqual(profile.offset_keypoint_index, [7, 8])
  self.assertEqual(profile.scale_keypoint_index_pairs, [([1, 2], [7, 8])])
  self.assertEqual(profile.keypoint_index('LEFT_SHOULDER'), 1)
  self.assertEqual(profile.keypoint_index('dummy'), -1)

  # Skeleton topology.
  expected_segments = [
      ([0], [1, 2]), ([1, 2], [1]), ([1, 2], [2]),
      ([1, 2], [1, 2, 7, 8]), ([1], [3]), ([2], [4]), ([3], [5]),
      ([4], [6]), ([1, 2, 7, 8], [7, 8]), ([7, 8], [7]), ([7, 8], [8]),
      ([7], [9]), ([8], [10]), ([9], [11]), ([10], [12]),
  ]
  self.assertEqual(profile.segment_index_pairs, expected_segments)
  expected_affinity = [
      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
      [1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0],
      [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
      [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
  ]
  self.assertAllEqual(profile.keypoint_affinity_matrix, expected_affinity)

  # Standard body-part properties, checked in a fixed order. Each entry maps
  # the `<part>_keypoint_index` property to its expected keypoint indices.
  expected_part_indices = (
      ('head', [0]),
      ('neck', [1, 2]),
      ('left_shoulder', [1]),
      ('right_shoulder', [2]),
      ('left_elbow', [3]),
      ('right_elbow', [4]),
      ('left_wrist', [5]),
      ('right_wrist', [6]),
      ('spine', [1, 2, 7, 8]),
      ('pelvis', [7, 8]),
      ('left_hip', [7]),
      ('right_hip', [8]),
      ('left_knee', [9]),
      ('right_knee', [10]),
      ('left_ankle', [11]),
      ('right_ankle', [12]),
  )
  for part, indices in expected_part_indices:
    attribute = part + '_keypoint_index'
    self.assertEqual(getattr(profile, attribute), indices, msg=attribute)
def main(_):
  """Runs inference and writes the embedding outputs as CSV files.

  Builds the embedder graph from command-line flags, evaluates it once in a
  monitored session created with `FLAGS.checkpoint_path`, and saves each
  embedding output that is present to `<FLAGS.output_dir>/<key>.csv`.

  Args:
    _: Unused positional argument.
  """
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
      FLAGS.input_keypoint_profile_name_2d)

  graph = tf.Graph()
  with graph.as_default():
    keypoints_2d, keypoint_masks_2d = read_inputs(profile_2d)
    model_inputs, _ = input_generator.create_model_input(
        keypoints_2d,
        keypoint_masks_2d=keypoint_masks_2d,
        keypoints_3d=None,
        model_input_keypoint_type=common.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
        model_input_keypoint_mask_type=FLAGS.model_input_keypoint_mask_type,
        keypoint_profile_2d=profile_2d,
        # Fix seed for determinism.
        seed=1)

    embedder_fn = models.get_embedder(
        base_model_type=FLAGS.base_model_type,
        embedding_type=FLAGS.embedding_type,
        num_embedding_components=FLAGS.num_embedding_components,
        embedding_size=FLAGS.embedding_size,
        num_embedding_samples=FLAGS.num_embedding_samples,
        is_training=False,
        num_fc_blocks=FLAGS.num_fc_blocks,
        num_fcs_per_block=FLAGS.num_fcs_per_block,
        num_hidden_nodes=FLAGS.num_hidden_nodes,
        num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
        weight_max_norm=FLAGS.weight_max_norm)
    outputs, _ = embedder_fn(model_inputs)

    # Pick which variables the saver restores.
    if not FLAGS.use_moving_average:
      saver = tf.train.Saver()
    else:
      saver = tf.train.Saver(
          pipeline_utils.get_moving_average_variables_to_restore())

    session_creator = tf.train.ChiefSessionCreator(
        scaffold=tf.train.Scaffold(
            init_op=tf.global_variables_initializer(), saver=saver),
        master=FLAGS.master,
        checkpoint_filename_with_path=FLAGS.checkpoint_path)
    with tf.train.MonitoredSession(
        session_creator=session_creator, hooks=None) as sess:
      outputs_result = sess.run(outputs)

  tf.gfile.MakeDirs(FLAGS.output_dir)
  exported_keys = (
      common.KEY_EMBEDDING_MEANS,
      common.KEY_EMBEDDING_STDDEVS,
      common.KEY_EMBEDDING_SAMPLES,
  )
  for key in exported_keys:
    if key not in outputs_result:
      continue
    values = outputs_result[key]
    csv_path = os.path.join(FLAGS.output_dir, key + '.csv')
    # Flatten everything after the first axis so each row holds one input.
    flattened = values.reshape([values.shape[0], -1])
    np.savetxt(csv_path, flattened, delimiter=',')
def test_legacy_h36m13_keypoint_profile_2d_is_correct(self):
  """Pins the public properties of the 'LEGACY_2DH36M13' keypoint profile."""
  profile = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_2DH36M13')
  left_right = keypoint_profiles.LeftRightType

  # Identity and keypoint layout.
  self.assertEqual(profile.name, 'LEGACY_2DH36M13')
  self.assertEqual(profile.keypoint_dim, 2)
  self.assertEqual(profile.keypoint_num, 13)
  legacy_names = [
      'Head', 'LShoulder', 'RShoulder', 'LElbow', 'RElbow', 'LWrist',
      'RWrist', 'LHip', 'RHip', 'LKnee', 'RKnee', 'LFoot', 'RFoot'
  ]
  self.assertEqual(profile.keypoint_names, legacy_names)

  # Left/right classification of a keypoint and of a segment.
  self.assertEqual(profile.keypoint_left_right_type(0), left_right.CENTRAL)
  self.assertEqual(profile.segment_left_right_type(0, 1), left_right.LEFT)

  # Normalization anchors and name lookup.
  self.assertEqual(profile.offset_keypoint_index, [7, 8])
  expected_scale_pairs = [
      ([1], [2]), ([1], [7]), ([1], [8]), ([2], [7]), ([2], [8]), ([7], [8])
  ]
  self.assertEqual(profile.scale_keypoint_index_pairs, expected_scale_pairs)
  self.assertEqual(profile.keypoint_index('LShoulder'), 1)
  self.assertEqual(profile.keypoint_index('dummy'), -1)

  # Skeleton topology.
  expected_segments = [
      ([0], [1]), ([0], [2]), ([1], [3]), ([3], [5]), ([2], [4]),
      ([4], [6]), ([1], [7]), ([2], [8]), ([7], [9]), ([9], [11]),
      ([8], [10]), ([10], [12]), ([1], [2]), ([7], [8]),
  ]
  self.assertEqual(profile.segment_index_pairs, expected_segments)
  expected_affinity = [
      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
      [1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
      [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
      [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
  ]
  self.assertAllEqual(profile.keypoint_affinity_matrix, expected_affinity)

  # Compatible keypoint names per 3D profile. Both standard profiles share one
  # name list, and the MPII-3DHP names are its lower-case form.
  standard_names = [
      'HEAD', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW', 'RIGHT_ELBOW',
      'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_HIP', 'RIGHT_HIP', 'LEFT_KNEE',
      'RIGHT_KNEE', 'LEFT_ANKLE', 'RIGHT_ANKLE'
  ]
  expected_compatible_names = {
      '3DSTD16': standard_names,
      '3DSTD13': standard_names,
      'LEGACY_3DH36M17': legacy_names,
      'LEGACY_3DMPII3DHP17': [name.lower() for name in standard_names],
  }
  self.assertEqual(profile.compatible_keypoint_name_dict,
                   expected_compatible_names)

  # Standard body-part properties, checked in a fixed order. Each entry maps
  # the `<part>_keypoint_index` property to its expected keypoint indices.
  expected_part_indices = (
      ('head', [0]),
      ('neck', [1, 2]),
      ('left_shoulder', [1]),
      ('right_shoulder', [2]),
      ('left_elbow', [3]),
      ('right_elbow', [4]),
      ('left_wrist', [5]),
      ('right_wrist', [6]),
      ('spine', [1, 2, 7, 8]),
      ('pelvis', [7, 8]),
      ('left_hip', [7]),
      ('right_hip', [8]),
      ('left_knee', [9]),
      ('right_knee', [10]),
      ('left_ankle', [11]),
      ('right_ankle', [12]),
  )
  for part, indices in expected_part_indices:
    attribute = part + '_keypoint_index'
    self.assertEqual(getattr(profile, attribute), indices, msg=attribute)
def test_legacy_mpii3dhp17_keypoint_profile_3d_is_correct(self):
  """Pins the public properties of the 'LEGACY_3DMPII3DHP17' profile."""
  profile = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DMPII3DHP17')
  left_right = keypoint_profiles.LeftRightType

  # Identity and keypoint layout.
  self.assertEqual(profile.name, 'LEGACY_3DMPII3DHP17')
  self.assertEqual(profile.keypoint_dim, 3)
  self.assertEqual(profile.keypoint_num, 17)
  expected_names = [
      'pelvis', 'head', 'neck', 'head_top', 'left_shoulder',
      'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist',
      'right_wrist', 'spine', 'left_hip', 'right_hip', 'left_knee',
      'right_knee', 'left_ankle', 'right_ankle'
  ]
  self.assertEqual(profile.keypoint_names, expected_names)

  # Left/right classification of a keypoint and of a segment.
  self.assertEqual(profile.keypoint_left_right_type(1), left_right.CENTRAL)
  self.assertEqual(profile.segment_left_right_type(1, 4), left_right.LEFT)

  # Normalization anchors and name lookup.
  self.assertEqual(profile.offset_keypoint_index, [0])
  self.assertEqual(profile.scale_keypoint_index_pairs,
                   [([0], [10]), ([10], [2])])
  self.assertEqual(profile.keypoint_index('left_shoulder'), 4)
  self.assertEqual(profile.keypoint_index('dummy'), -1)

  # Skeleton topology.
  expected_segments = [
      ([0], [10]), ([0], [11]), ([0], [12]), ([10], [2]), ([11], [13]),
      ([12], [14]), ([13], [15]), ([14], [16]), ([2], [1]), ([2], [4]),
      ([2], [5]), ([1], [3]), ([4], [6]), ([5], [7]), ([6], [8]),
      ([7], [9]),
  ]
  self.assertEqual(profile.segment_index_pairs, expected_segments)
  expected_affinity = [
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
      [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
  ]
  self.assertAllEqual(profile.keypoint_affinity_matrix, expected_affinity)

  # Standard body-part properties, checked in a fixed order. Each entry maps
  # the `<part>_keypoint_index` property to its expected keypoint indices.
  expected_part_indices = (
      ('head', [1]),
      ('neck', [2]),
      ('left_shoulder', [4]),
      ('right_shoulder', [5]),
      ('left_elbow', [6]),
      ('right_elbow', [7]),
      ('left_wrist', [8]),
      ('right_wrist', [9]),
      ('spine', [10]),
      ('pelvis', [0]),
      ('left_hip', [11]),
      ('right_hip', [12]),
      ('left_knee', [13]),
      ('right_knee', [14]),
      ('left_ankle', [15]),
      ('right_ankle', [16]),
  )
  for part, indices in expected_part_indices:
    attribute = part + '_keypoint_index'
    self.assertEqual(getattr(profile, attribute), indices, msg=attribute)
def test_legacy_h36m17_keypoint_profile_3d_is_correct(self):
  """Pins the public properties of the 'LEGACY_3DH36M17' keypoint profile."""
  profile = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M17')
  left_right = keypoint_profiles.LeftRightType

  # Identity and keypoint layout.
  self.assertEqual(profile.name, 'LEGACY_3DH36M17')
  self.assertEqual(profile.keypoint_dim, 3)
  self.assertEqual(profile.keypoint_num, 17)
  expected_names = [
      'Hip', 'Head', 'Neck/Nose', 'Thorax', 'LShoulder', 'RShoulder',
      'LElbow', 'RElbow', 'LWrist', 'RWrist', 'Spine', 'LHip', 'RHip',
      'LKnee', 'RKnee', 'LFoot', 'RFoot'
  ]
  self.assertEqual(profile.keypoint_names, expected_names)

  # Left/right classification of a keypoint and of a segment.
  self.assertEqual(profile.keypoint_left_right_type(1), left_right.CENTRAL)
  self.assertEqual(profile.segment_left_right_type(1, 4), left_right.LEFT)

  # Normalization anchors and name lookup.
  self.assertEqual(profile.offset_keypoint_index, [0])
  self.assertEqual(profile.scale_keypoint_index_pairs,
                   [([0], [10]), ([10], [3])])
  self.assertEqual(profile.keypoint_index('Thorax'), 3)
  self.assertEqual(profile.keypoint_index('dummy'), -1)

  # Skeleton topology.
  expected_segments = [
      ([0], [10]), ([0], [11]), ([0], [12]), ([10], [3]), ([11], [13]),
      ([12], [14]), ([13], [15]), ([14], [16]), ([3], [2]), ([3], [4]),
      ([3], [5]), ([2], [1]), ([4], [6]), ([5], [7]), ([6], [8]),
      ([7], [9]),
  ]
  self.assertEqual(profile.segment_index_pairs, expected_segments)
  expected_affinity = [
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
      [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
      [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
  ]
  self.assertAllEqual(profile.keypoint_affinity_matrix, expected_affinity)

  # One table drives three groups of checks: the `<part>_keypoint_index`
  # properties, the ordered `standard_part_names` list, and the
  # `get_standard_part_index` lookups.
  standard_parts = (
      ('HEAD', [1]),
      ('NECK', [3]),
      ('LEFT_SHOULDER', [4]),
      ('RIGHT_SHOULDER', [5]),
      ('LEFT_ELBOW', [6]),
      ('RIGHT_ELBOW', [7]),
      ('LEFT_WRIST', [8]),
      ('RIGHT_WRIST', [9]),
      ('SPINE', [10]),
      ('PELVIS', [0]),
      ('LEFT_HIP', [11]),
      ('RIGHT_HIP', [12]),
      ('LEFT_KNEE', [13]),
      ('RIGHT_KNEE', [14]),
      ('LEFT_ANKLE', [15]),
      ('RIGHT_ANKLE', [16]),
  )
  for part_name, indices in standard_parts:
    attribute = part_name.lower() + '_keypoint_index'
    self.assertEqual(getattr(profile, attribute), indices, msg=attribute)

  self.assertEqual(profile.standard_part_names,
                   [part_name for part_name, _ in standard_parts])
  for part_name, indices in standard_parts:
    self.assertEqual(
        profile.get_standard_part_index(part_name), indices, msg=part_name)
def test_transfer_keypoint_masks_case_2(self):
  """Transfers '3DSTD16' keypoint masks onto the '2DSTD13' profile."""
  # Shape = [2, 16]. Entry order, as labelled in the original fixture: NOSE,
  # NECK, LEFT/RIGHT_SHOULDER, LEFT/RIGHT_ELBOW, LEFT/RIGHT_WRIST, SPINE,
  # PELVIS, LEFT/RIGHT_HIP, LEFT/RIGHT_KNEE, LEFT/RIGHT_ANKLE.
  source_masks = tf.constant([
      [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0,
       0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0,
       1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
  ])
  source_profile = keypoint_profiles.create_keypoint_profile_or_die('3DSTD16')
  target_profile = keypoint_profiles.create_keypoint_profile_or_die('2DSTD13')

  # Shape = [2, 13].
  transferred_masks = keypoint_utils.transfer_keypoint_masks(
      source_masks, source_profile, target_profile)

  # Entry order, as labelled in the original fixture: NOSE_TIP,
  # LEFT/RIGHT_SHOULDER, LEFT/RIGHT_ELBOW, LEFT/RIGHT_WRIST, LEFT/RIGHT_HIP,
  # LEFT/RIGHT_KNEE, LEFT/RIGHT_ANKLE.
  expected_masks = [
      [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
  ]
  self.assertAllClose(transferred_masks, expected_masks)
def test_preprocess_keypoints_2d_with_projection(self):
  """Projects 3D poses to 2D model inputs using pinned camera angles."""
  # Every 3D pose has 17 keypoints: 4 pose-specific leading points followed by
  # the same 13 trailing points.
  base_points = [
      [1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [7.0, 8.0, 9.0]
  ]
  shared_tail = base_points * 3 + [[7.0, 8.0, 9.0]]
  pose_a = base_points + shared_tail
  pose_b = [
      [11.0, 12.0, 13.0], [13.0, 14.0, 15.0], [15.0, 16.0, 17.0],
      [17.0, 18.0, 19.0]
  ] + shared_tail
  pose_c = [
      [31.0, 32.0, 33.0], [33.0, 34.0, 35.0], [35.0, 36.0, 37.0],
      [37.0, 38.0, 39.0]
  ] + shared_tail
  # NOTE(review): the second point's z is 35.0, breaking the +2 pattern of its
  # neighbours. Kept as-is; the expected values below were presumably
  # generated from this exact fixture.
  pose_d = [
      [41.0, 42.0, 43.0], [43.0, 44.0, 35.0], [45.0, 46.0, 47.0],
      [47.0, 48.0, 49.0]
  ] + shared_tail

  # Shape = [4, 2, 17, 3].
  keypoints_3d = tf.constant([
      [pose_a, pose_b],
      [pose_c, pose_d],
      [pose_a, pose_b],
      [pose_c, pose_a],
  ])

  profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M17')
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_2DCOCO13')

  # Each range has identical endpoints, pinning the camera angle to a single
  # value.
  half_pi = math.pi / 2.0
  projected, _ = input_generator.preprocess_keypoints_2d(
      keypoints_2d=None,
      keypoint_masks_2d=None,
      keypoints_3d=keypoints_3d,
      model_input_keypoint_type=common.MODEL_INPUT_KEYPOINT_TYPE_3D_PROJECTION,
      keypoint_profile_2d=profile_2d,
      keypoint_profile_3d=profile_3d,
      azimuth_range=(half_pi, half_pi),
      elevation_range=(-half_pi, -half_pi),
      roll_range=(math.pi, math.pi))

  # Expected 13-keypoint 2D output for each distinct 3D pose above.
  expected_a = [
      [-0.08777856, 0.08777854], [0.0, 0.0], [-0.08777856, 0.08777854],
      [-0.1613905, 0.16139045], [-0.22400925, 0.22400923], [0.0, 0.0],
      [-0.08777856, 0.08777854], [-0.22400925, 0.22400923], [0.0, 0.0],
      [-0.08777856, 0.08777854], [-0.1613905, 0.16139045],
      [-0.22400925, 0.22400923], [-0.22400925, 0.22400923]
  ]
  expected_b = [
      [-0.03107818, 0.03107818], [0.191008, -0.19100799],
      [0.14718375, -0.14718372], [0.10647015, -0.10647012],
      [0.06854735, -0.06854733], [0.191008, -0.19100799],
      [0.14718375, -0.14718372], [0.06854735, -0.06854733],
      [0.191008, -0.19100799], [0.14718375, -0.14718372],
      [0.10647015, -0.10647012], [0.06854735, -0.06854733],
      [0.06854735, -0.06854733]
  ]
  expected_c = [
      [-0.0098562, 0.0098562], [0.1755229, -0.17552288],
      [0.16192658, -0.16192657], [0.14864118, -0.14864117],
      [0.13565618, -0.13565615], [0.1755229, -0.17552288],
      [0.16192658, -0.16192657], [0.13565618, -0.13565615],
      [0.1755229, -0.17552288], [0.16192658, -0.16192657],
      [0.14864118, -0.14864117], [0.13565618, -0.13565615],
      [0.13565618, -0.13565615]
  ]
  expected_d = [
      [-0.00762777, 0.00762777], [0.17376202, -0.173762],
      [0.16365208, -0.16365206], [0.15371482, -0.1537148],
      [0.14394586, -0.14394584], [0.17376202, -0.173762],
      [0.16365208, -0.16365206], [0.14394586, -0.14394584],
      [0.17376202, -0.173762], [0.16365208, -0.16365206],
      [0.15371482, -0.1537148], [0.14394586, -0.14394584],
      [0.14394586, -0.14394584]
  ]
  # Mirrors the arrangement of `keypoints_3d`.
  expected_keypoints_2d = [
      [expected_a, expected_b],
      [expected_c, expected_d],
      [expected_a, expected_b],
      [expected_c, expected_a],
  ]
  self.assertAllClose(projected, expected_keypoints_2d)
def test_preprocess_keypoints_2d_with_input_and_projection(self):
  """Checks mixing of raw 2D inputs with projected 3D keypoints.

  Samples flagged `True` in the assignment are expected to come back as the
  untouched 2D input; samples flagged `False` are expected to hold the
  projection of the matching 3D keypoints.
  """
  # Every 2D sample shares its last 9 (of 13) points; only the first 4 vary.
  shared_tail_2d = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
                    [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
                    [1.0, 2.0]]
  points_2d_a = (
      [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]] + shared_tail_2d)
  points_2d_b = (
      [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0]] +
      shared_tail_2d)
  points_2d_c = (
      [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0]] +
      shared_tail_2d)
  points_2d_d = (
      [[41.0, 42.0], [43.0, 44.0], [45.0, 46.0], [47.0, 48.0]] +
      shared_tail_2d)
  # Shape = [4, 2, 13, 2].
  input_keypoints_2d = tf.constant([
      [points_2d_a, points_2d_b],
      [points_2d_c, points_2d_d],
      [points_2d_a, points_2d_b],
      [points_2d_c, points_2d_d],
  ])

  # Shape = [4, 2, 13].
  input_keypoint_masks_2d = tf.constant([
      [[0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11,
        0.12, 0.13],
       [0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24,
        0.25, 0.26]],
      [[0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37,
        0.38, 0.39],
       [0.40, 0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.50,
        0.51, 0.52]],
      [[0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.60, 0.61, 0.62, 0.63,
        0.64, 0.65],
       [0.66, 0.67, 0.68, 0.69, 0.70, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
        0.77, 0.78]],
      [[0.79, 0.80, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89,
        0.90, 0.91],
       [0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.00, 0.99, 0.98,
        0.97, 0.96]],
  ])

  # Every 3D sample shares its last 13 (of 17) points; only the first 4 vary.
  shared_tail_3d = (
      [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [7.0, 8.0, 9.0]] * 3
      + [[7.0, 8.0, 9.0]])
  points_3d_a = (
      [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [7.0, 8.0, 9.0]] +
      shared_tail_3d)
  points_3d_b = (
      [[11.0, 12.0, 13.0], [13.0, 14.0, 15.0], [15.0, 16.0, 17.0],
       [17.0, 18.0, 19.0]] + shared_tail_3d)
  points_3d_c = (
      [[31.0, 32.0, 33.0], [33.0, 34.0, 35.0], [35.0, 36.0, 37.0],
       [37.0, 38.0, 39.0]] + shared_tail_3d)
  # The second point's z value (35.0) is reproduced as-is from the fixture.
  points_3d_d = (
      [[41.0, 42.0, 43.0], [43.0, 44.0, 35.0], [45.0, 46.0, 47.0],
       [47.0, 48.0, 49.0]] + shared_tail_3d)
  # Shape = [4, 2, 17, 3].
  input_keypoints_3d = tf.constant([
      [points_3d_a, points_3d_b],
      [points_3d_c, points_3d_d],
      [points_3d_a, points_3d_b],
      [points_3d_c, points_3d_a],
  ])

  # Shape = [4, 2].
  assignment = tf.constant([[True, True], [False, False], [True, False],
                            [False, True]])

  profile_3d = (
      keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17'))
  profile_2d = (
      keypoint_profiles.create_keypoint_profile_or_die('LEGACY_2DCOCO13'))

  # Degenerate angle ranges (min == max) are used for every rotation.
  quarter_turn = math.pi / 2.0
  output_keypoints_2d, _ = input_generator.preprocess_keypoints_2d(
      input_keypoints_2d,
      input_keypoint_masks_2d,
      input_keypoints_3d,
      model_input_keypoint_type=(
          common.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT_AND_3D_PROJECTION),
      keypoint_profile_2d=profile_2d,
      keypoint_profile_3d=profile_3d,
      azimuth_range=(quarter_turn, quarter_turn),
      elevation_range=(-quarter_turn, -quarter_turn),
      roll_range=(math.pi, math.pi),
      projection_mix_batch_assignment=assignment)

  # Expected projections of `points_3d_b`, `points_3d_c` and `points_3d_d`.
  projected_b = [[-0.03107818, 0.03107818], [0.191008, -0.19100799],
                 [0.14718375, -0.14718372], [0.10647015, -0.10647012],
                 [0.06854735, -0.06854733], [0.191008, -0.19100799],
                 [0.14718375, -0.14718372], [0.06854735, -0.06854733],
                 [0.191008, -0.19100799], [0.14718375, -0.14718372],
                 [0.10647015, -0.10647012], [0.06854735, -0.06854733],
                 [0.06854735, -0.06854733]]
  projected_c = [[-0.0098562, 0.0098562], [0.1755229, -0.17552288],
                 [0.16192658, -0.16192657], [0.14864118, -0.14864117],
                 [0.13565618, -0.13565615], [0.1755229, -0.17552288],
                 [0.16192658, -0.16192657], [0.13565618, -0.13565615],
                 [0.1755229, -0.17552288], [0.16192658, -0.16192657],
                 [0.14864118, -0.14864117], [0.13565618, -0.13565615],
                 [0.13565618, -0.13565615]]
  projected_d = [[-0.00762777, 0.00762777], [0.17376202, -0.173762],
                 [0.16365208, -0.16365206], [0.15371482, -0.1537148],
                 [0.14394586, -0.14394584], [0.17376202, -0.173762],
                 [0.16365208, -0.16365206], [0.14394586, -0.14394584],
                 [0.17376202, -0.173762], [0.16365208, -0.16365206],
                 [0.15371482, -0.1537148], [0.14394586, -0.14394584],
                 [0.14394586, -0.14394584]]

  # Layout mirrors `assignment`: True -> raw 2D input, False -> projection.
  expected_keypoints_2d = [
      [points_2d_a, points_2d_b],
      [projected_c, projected_d],
      [points_2d_a, projected_b],
      [projected_c, points_2d_d],
  ]
  self.assertAllClose(output_keypoints_2d, expected_keypoints_2d)
def test_preprocess_keypoints_2d_with_projection(self):
  """Checks 2D keypoints produced purely by projecting 3D keypoints."""
  # Every 3D sample shares its last 13 (of 17) points; only the first 4 vary.
  shared_tail_3d = (
      [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [7.0, 8.0, 9.0]] * 3
      + [[7.0, 8.0, 9.0]])
  points_3d_a = (
      [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [7.0, 8.0, 9.0]] +
      shared_tail_3d)
  points_3d_b = (
      [[11.0, 12.0, 13.0], [13.0, 14.0, 15.0], [15.0, 16.0, 17.0],
       [17.0, 18.0, 19.0]] + shared_tail_3d)
  points_3d_c = (
      [[31.0, 32.0, 33.0], [33.0, 34.0, 35.0], [35.0, 36.0, 37.0],
       [37.0, 38.0, 39.0]] + shared_tail_3d)
  # The second point's z value (35.0) is reproduced as-is from the fixture.
  points_3d_d = (
      [[41.0, 42.0, 43.0], [43.0, 44.0, 35.0], [45.0, 46.0, 47.0],
       [47.0, 48.0, 49.0]] + shared_tail_3d)
  # Shape = [4, 2, 17, 3].
  input_keypoints_3d = tf.constant([
      [points_3d_a, points_3d_b],
      [points_3d_c, points_3d_d],
      [points_3d_a, points_3d_b],
      [points_3d_c, points_3d_a],
  ])

  profile_3d = (
      keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17'))
  profile_2d = (
      keypoint_profiles.create_keypoint_profile_or_die('LEGACY_2DCOCO13'))

  # Degenerate angle ranges (min == max) are used for every rotation.
  quarter_turn = math.pi / 2.0
  output_keypoints_2d, _ = input_generator.preprocess_keypoints_2d(
      keypoints_2d=None,
      keypoint_masks_2d=None,
      keypoints_3d=input_keypoints_3d,
      model_input_keypoint_type=(
          common.MODEL_INPUT_KEYPOINT_TYPE_3D_PROJECTION),
      keypoint_profile_2d=profile_2d,
      keypoint_profile_3d=profile_3d,
      azimuth_range=(quarter_turn, quarter_turn),
      elevation_range=(-quarter_turn, -quarter_turn),
      roll_range=(math.pi, math.pi))

  # The expected values below were copied from test output, so this test
  # mainly guards that the projection path keeps executing and stays stable.
  # The actual projection accuracy is tested separately.
  projected_a = [[-0.08777856, -0.08777856], [0., 0.],
                 [-0.08777856, -0.08777856], [-0.1613905, -0.1613905],
                 [-0.22400928, -0.22400929], [0., 0.],
                 [-0.08777856, -0.08777856], [-0.22400928, -0.22400929],
                 [0., 0.], [-0.08777856, -0.08777856],
                 [-0.1613905, -0.1613905], [-0.22400928, -0.22400929],
                 [-0.22400928, -0.22400929]]
  projected_b = [[-0.03107818, -0.03107818], [0.19100799, 0.19100802],
                 [0.14718375, 0.14718376], [0.10647015, 0.10647015],
                 [0.06854735, 0.06854735], [0.19100799, 0.19100802],
                 [0.14718375, 0.14718376], [0.06854735, 0.06854735],
                 [0.19100799, 0.19100802], [0.14718375, 0.14718376],
                 [0.10647015, 0.10647015], [0.06854735, 0.06854735],
                 [0.06854735, 0.06854735]]
  projected_c = [[-0.0098562, -0.0098562], [0.17552288, 0.1755229],
                 [0.16192658, 0.1619266], [0.14864118, 0.1486412],
                 [0.13565616, 0.13565616], [0.17552288, 0.1755229],
                 [0.16192658, 0.1619266], [0.13565616, 0.13565616],
                 [0.17552288, 0.1755229], [0.16192658, 0.1619266],
                 [0.14864118, 0.1486412], [0.13565616, 0.13565616],
                 [0.13565616, 0.13565616]]
  projected_d = [[-0.00734754, 0.02939016], [0.17376201, 0.17376202],
                 [0.16365208, 0.16365209], [0.15371482, 0.15371484],
                 [0.14394586, 0.14394587], [0.17376201, 0.17376202],
                 [0.16365208, 0.16365209], [0.14394586, 0.14394587],
                 [0.17376201, 0.17376202], [0.16365208, 0.16365209],
                 [0.15371482, 0.15371484], [0.14394586, 0.14394587],
                 [0.14394586, 0.14394587]]

  # Same sample layout as `input_keypoints_3d`.
  expected_keypoints_2d = [
      [projected_a, projected_b],
      [projected_c, projected_d],
      [projected_a, projected_b],
      [projected_c, projected_a],
  ]
  self.assertAllClose(output_keypoints_2d, expected_keypoints_2d)
def load_2d_keypoints_and_write_tfrecord_with_3d_keypoints(
    input_csv_file, keypoint_dict, output_tfrecord_file, read_csv_pairs,
    num_shards):
  """Loads 2D keypoints from a CSV file and writes TFRecord with 3D poses.

  The TFRecord written contains the 2D keypoints with corresponding 3D
  keypoints stored in keypoint_dict. The first CSV row is treated as the
  header. Rows are distributed over the output shards round-robin by their
  row index in the CSV file.

  Args:
    input_csv_file: A string of the CSV file name containing 2D keypoints to
      load with subjects, actions and timestamps to be matched to 3D keypoints
      from keypoint_dict.
    keypoint_dict: A dictionary for loaded 3D keypoints. Keys are (subject,
      action) and values are of shape [sequence_length, num_keypoints, 3].
    output_tfrecord_file: A string of output filename for the TFRecord
      containing 2D and 3D keypoints.
    read_csv_pairs: A boolean that is True when each row of the CSV file
      stores paired entries and is False when the row contains a single entry.
    num_shards: An integer for the number of shards in the output TFRecord
      file. Values of 1 or less produce a single, unsharded output file.

  Raises:
    ValueError: If `read_csv_pairs` is True and a row has an odd number of
      elements.
  """
  keypoint_profile_h36m17 = (
      keypoint_profiles.create_keypoint_profile_or_die('LEGACY_3DH36M17'))

  if num_shards > 1:
    output_paths = [
        output_tfrecord_file + '-{:05d}-of-{:05d}'.format(i, num_shards)
        for i in range(num_shards)
    ]
  else:
    output_paths = [output_tfrecord_file]

  def lookup_flat_keypoints_3d(subject, action, frame_index):
    """Returns the flattened 3D keypoints matching one CSV entry."""
    # Replace names to be consistent with H5 file names.
    action = action.replace('TakingPhoto', 'Photo').replace(
        'WalkingDog', 'WalkDog')
    return keypoint_dict[(subject, action)][
        frame_index, KEYPOINT_3D_INDICES_H5, :].reshape([-1])

  tfrecord_writers = []
  try:
    for output_path in output_paths:
      tfrecord_writers.append(tf.python_io.TFRecordWriter(output_path))

    # Populated from the first row of the file, which is the header.
    headers = None
    with tf.io.gfile.GFile(input_csv_file, 'r') as csv_rows:
      for shard_counter, row in enumerate(csv_rows):
        # Index by the number of writers actually opened so that
        # `num_shards <= 0` cannot trigger a modulo-by-zero or a bad index.
        writer = tfrecord_writers[shard_counter % len(tfrecord_writers)]
        # NOTE(review): the last field presumably keeps the trailing newline
        # of the CSV line; confirm that downstream parsing tolerates it.
        row = row.split(',')

        feature_size = len(row)
        if read_csv_pairs:
          if len(row) % 2 != 0:
            raise ValueError('CSV row has length {} but it should have an '
                             'even number of elements.'.format(len(row)))
          feature_size = len(row) // 2

        if headers is None:
          # Keep the first half of the row as header if the csv file contains
          # pairs. Otherwise, keep the full row.
          headers = row[:feature_size]
          # Add 3D pose headers using the keypoint names to the header list.
          prefix = 'image/object/part_3d/'
          suffix = '/center/'
          for name in keypoint_profile_h36m17.keypoint_names:
            for axis in ('x', 'y', 'z'):
              headers.append(prefix + name + suffix + axis)
          continue

        # Obtain matching 3D keypoints from keypoint_dict.
        anchor_keypoint_3d = lookup_flat_keypoints_3d(
            row[0], row[1], int(row[3]))

        if read_csv_pairs:
          # The second element in the pair in the row is the positive match.
          positive_keypoint_3d = lookup_flat_keypoints_3d(
              row[feature_size], row[feature_size + 1],
              int(row[feature_size + 3]))
          # Concatenate 3D keypoints into current row with 2D keypoints.
          row_with_3d_keypoints = np.concatenate(
              (row[:feature_size], anchor_keypoint_3d, row[feature_size:],
               positive_keypoint_3d))
        else:
          row_with_3d_keypoints = np.concatenate((row, anchor_keypoint_3d))

        serialized_example = create_serialized_example_with_2d_3d_keypoints(
            row_with_3d_keypoints, headers, write_pairs=read_csv_pairs)
        writer.write(serialized_example)
  finally:
    # Always close the writers, including on error, so that records already
    # written are flushed and file handles are not leaked.
    for writer in tfrecord_writers:
      writer.close()
def _validate_and_setup(common_module, keypoint_distance_config_override):
  """Validates training flags and sets up training configurations.

  Unspecified positive-pairwise flags (and the triplet-mining normalization
  flag) are first defaulted to their triplet counterparts, then all relevant
  flags are validated before the configuration dictionary is built. When
  `FLAGS.task == 0` and `FLAGS.profile_only` is not set, all flags are also
  saved under `FLAGS.train_log_dir`.

  Args:
    common_module: A module providing `validate` and the constants (supported
      value collections, distance types, kernels, reductions) read below.
    keypoint_distance_config_override: A dictionary that may contain
      `keypoint_distance_fn` and/or `min_negative_keypoint_distance` entries,
      which take precedence over the flag-derived values.

  Returns:
    A dictionary of training configurations.

  Raises:
    ValueError: If a flag combination is unsupported, or if the keypoint
      distance function or minimum negative keypoint distance is left
      unconfigured.
  """
  # Set default values for unspecified flags.
  if FLAGS.use_normalized_embeddings_for_triplet_mining is None:
    FLAGS.use_normalized_embeddings_for_triplet_mining = (
        FLAGS.use_normalized_embeddings_for_triplet_loss)
  if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss is None:
    FLAGS.use_normalized_embeddings_for_positive_pairwise_loss = (
        FLAGS.use_normalized_embeddings_for_triplet_loss)
  if FLAGS.positive_pairwise_distance_type is None:
    FLAGS.positive_pairwise_distance_type = FLAGS.triplet_distance_type
  if FLAGS.positive_pairwise_distance_kernel is None:
    FLAGS.positive_pairwise_distance_kernel = FLAGS.triplet_distance_kernel
  if FLAGS.positive_pairwise_pairwise_reduction is None:
    FLAGS.positive_pairwise_pairwise_reduction = (
        FLAGS.triplet_pairwise_reduction)
  if FLAGS.positive_pairwise_componentwise_reduction is None:
    FLAGS.positive_pairwise_componentwise_reduction = (
        FLAGS.triplet_componentwise_reduction)

  # Validate flags.
  validate_flag = common_module.validate
  validate_flag(FLAGS.model_input_keypoint_type,
                common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)
  validate_flag(FLAGS.embedding_type, common_module.SUPPORTED_EMBEDDING_TYPES)
  validate_flag(FLAGS.keypoint_distance_type,
                common_module.SUPPORTED_KEYPOINT_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.triplet_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.triplet_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.triplet_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_distance_type,
                common_module.SUPPORTED_DISTANCE_TYPES)
  validate_flag(FLAGS.positive_pairwise_distance_kernel,
                common_module.SUPPORTED_DISTANCE_KERNELS)
  validate_flag(FLAGS.positive_pairwise_pairwise_reduction,
                common_module.SUPPORTED_PAIRWISE_DISTANCE_REDUCTIONS)
  validate_flag(FLAGS.positive_pairwise_componentwise_reduction,
                common_module.SUPPORTED_COMPONENTWISE_DISTANCE_REDUCTIONS)

  # Value groups that several of the cross-flag checks below share.
  sample_distance_types = [
      common_module.DISTANCE_TYPE_SAMPLE,
      common_module.DISTANCE_TYPE_CENTER_AND_SAMPLE
  ]
  paired_distance_kernels = [
      common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB,
      common_module.DISTANCE_KERNEL_EXPECTED_LIKELIHOOD
  ]
  paired_pairwise_reductions = [
      common_module.DISTANCE_REDUCTION_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_LOWER_HALF_NEG_LOG_MEAN,
      common_module.DISTANCE_REDUCTION_ONE_MINUS_MEAN
  ]

  if FLAGS.embedding_type == common_module.EMBEDDING_TYPE_POINT:
    if FLAGS.triplet_distance_type in sample_distance_types:
      raise ValueError(
          'No support for triplet distance type `%s` for embedding type `%s`.'
          % (FLAGS.triplet_distance_type, FLAGS.embedding_type))
    if FLAGS.kl_regularization_loss_weight > 0.0:
      raise ValueError(
          'No support for KL regularization loss for embedding type `%s`.' %
          FLAGS.embedding_type)

  if ((FLAGS.triplet_distance_type in sample_distance_types or
       FLAGS.positive_pairwise_distance_type in sample_distance_types) and
      FLAGS.num_embedding_samples <= 0):
    # Report both flags: either one of them may have triggered this check.
    raise ValueError(
        'Must specify positive `num_embedding_samples` to use sample-based '
        'distance types (triplet: `%s`, positive pairwise: `%s`).' %
        (FLAGS.triplet_distance_type, FLAGS.positive_pairwise_distance_type))

  # The kernels and pairwise reductions above must be used together: using
  # one group without the other is rejected, for both loss kinds.
  triplet_mismatch = (
      (FLAGS.triplet_distance_kernel in paired_distance_kernels) !=
      (FLAGS.triplet_pairwise_reduction in paired_pairwise_reductions))
  positive_pairwise_mismatch = (
      (FLAGS.positive_pairwise_distance_kernel in paired_distance_kernels) !=
      (FLAGS.positive_pairwise_pairwise_reduction in
       paired_pairwise_reductions))
  if triplet_mismatch or positive_pairwise_mismatch:
    raise ValueError(
        'Must use `L2_SIGMOID_MATCHING_PROB` or `EXPECTED_LIKELIHOOD` distance '
        'kernel and `NEG_LOG_MEAN`, `LOWER_HALF_NEG_LOG_MEAN` or '
        '`ONE_MINUS_MEAN` pairwise reducer in pairs.')

  # Set up configurations.
  configs = {
      'keypoint_profile_3d':
          keypoint_profiles.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_3d),
      'keypoint_profile_2d':
          keypoint_profiles.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_2d),
      'target_keypoint_profile_3d':
          keypoint_profiles.create_keypoint_profile_or_die(
              FLAGS.input_keypoint_profile_name_3d),
      'triplet_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type, common_module=common_module),
      'triplet_mining_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.triplet_distance_type,
              replace_samples_with_means=True,
              common_module=common_module),
      'triplet_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.triplet_distance_kernel,
              pairwise_reduction=FLAGS.triplet_pairwise_reduction,
              componentwise_reduction=FLAGS.triplet_componentwise_reduction,
              # We initialize the sigmoid parameters to avoid model being
              # stuck in a `dead zone` at the beginning of training.
              L2_SIGMOID_MATCHING_PROB_a_initializer=(
                  tf.initializers.constant(-0.65)),
              L2_SIGMOID_MATCHING_PROB_b_initializer=(
                  tf.initializers.constant(-0.5)),
              EXPECTED_LIKELIHOOD_min_stddev=0.1,
              EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
      'positive_pairwise_embedding_keys':
          pipeline_utils.get_embedding_keys(
              FLAGS.positive_pairwise_distance_type,
              common_module=common_module),
      'positive_pairwise_embedding_sample_distance_fn':
          loss_utils.create_sample_distance_fn(
              pair_type=common_module.DISTANCE_PAIR_TYPE_ALL_PAIRS,
              distance_kernel=FLAGS.positive_pairwise_distance_kernel,
              pairwise_reduction=FLAGS.positive_pairwise_pairwise_reduction,
              componentwise_reduction=(
                  FLAGS.positive_pairwise_componentwise_reduction),
              # We initialize the sigmoid parameters to avoid model being
              # stuck in a `dead zone` at the beginning of training.
              L2_SIGMOID_MATCHING_PROB_a_initializer=(
                  tf.initializers.constant(-0.65)),
              L2_SIGMOID_MATCHING_PROB_b_initializer=(
                  tf.initializers.constant(-0.5)),
              EXPECTED_LIKELIHOOD_min_stddev=0.1,
              EXPECTED_LIKELIHOOD_max_squared_mahalanobis_distance=100.0),
      'summarize_matching_sigmoid_vars':
          FLAGS.triplet_distance_kernel in
          [common_module.DISTANCE_KERNEL_L2_SIGMOID_MATCHING_PROB],
      # The projection range flags are scaled by pi / 180 before use.
      'random_projection_azimuth_range': [
          float(x) / 180.0 * math.pi
          for x in FLAGS.random_projection_azimuth_range
      ],
      'random_projection_elevation_range': [
          float(x) / 180.0 * math.pi
          for x in FLAGS.random_projection_elevation_range
      ],
      'random_projection_roll_range': [
          float(x) / 180.0 * math.pi
          for x in FLAGS.random_projection_roll_range
      ],
  }

  if FLAGS.keypoint_distance_type == common_module.KEYPOINT_DISTANCE_TYPE_MPJPE:
    configs.update({
        'keypoint_distance_fn':
            keypoint_utils.compute_procrustes_aligned_mpjpes,
        'min_negative_keypoint_distance':
            FLAGS.min_negative_keypoint_mpjpe
    })
  # We use the following assignments to get around pytype check failures.
  # TODO(liuti): Figure out a better workaround.
  if 'keypoint_distance_fn' in keypoint_distance_config_override:
    configs['keypoint_distance_fn'] = (
        keypoint_distance_config_override['keypoint_distance_fn'])
  if 'min_negative_keypoint_distance' in keypoint_distance_config_override:
    configs['min_negative_keypoint_distance'] = (
        keypoint_distance_config_override['min_negative_keypoint_distance'])
  if ('keypoint_distance_fn' not in configs or
      'min_negative_keypoint_distance' not in configs):
    raise ValueError('Invalid keypoint distance config: %s.' % str(configs))

  if FLAGS.task == 0 and not FLAGS.profile_only:
    # Save all key flags.
    pipeline_utils.create_dir_and_save_flags(flags, FLAGS.train_log_dir,
                                             'all_flags.train.json')

  return configs