def test_read_batch_from_dataset_tables(self):
  """Checks output keys and static shapes when reading from two tables.

  The same test table is passed twice with batch sizes 4 and 2. Every
  returned tensor is asserted to have a leading dimension of 6 (presumably
  the per-table batches concatenated -- confirm against
  `pipeline_utils.read_batch_from_dataset_tables`).
  """
  # Assume $PWD == "google_research/".
  tfrecord_path = os.path.join(
      FLAGS.test_srcdir, 'poem/testdata', 'tfe-2.tfrecords')

  profile_3d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_3DH36M17')
  profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
      'LEGACY_2DCOCO13')

  batch = pipeline_utils.read_batch_from_dataset_tables(
      [tfrecord_path, tfrecord_path],
      batch_sizes=[4, 2],
      num_instances_per_record=2,
      shuffle=True,
      num_epochs=None,
      keypoint_names_3d=profile_3d.keypoint_names,
      keypoint_names_2d=profile_2d.keypoint_names,
      seed=0)

  # (output key, expected static shape), in the order they are checked.
  expected_shapes = (
      ('image_sizes', [6, 2, 2]),
      ('keypoints_2d', [6, 2, 13, 2]),
      ('keypoint_scores_2d', [6, 2, 13]),
      ('keypoint_masks_2d', [6, 2, 13]),
      ('keypoints_3d', [6, 2, 17, 3]),
  )

  self.assertCountEqual(batch.keys(), [key for key, _ in expected_shapes])
  for key, shape in expected_shapes:
    self.assertEqual(batch[key].shape, shape)
def create_inputs():
  """Builds the input batch and derives model inputs from it.

  Steps, in order:
    1. Read a batch from `FLAGS.input_table`.
    2. Run `keypoint_preprocessor_3d` on the 3D keypoints, store its first
       output back under `common_module.KEY_KEYPOINTS_3D`, and add its side
       outputs to the batch.
    3. Call `configs['create_model_input_fn']` and store its first output
       under 'model_inputs'.
    4. Merge the side inputs from step 3 via `data_utils.merge_dict`.

  Returns:
    The dict-like `inputs` object returned by the reader, after the updates
    above.
  """
  profile_2d = configs['keypoint_profile_2d']
  profile_3d = configs['keypoint_profile_3d']

  inputs = pipeline_utils.read_batch_from_dataset_tables(
      FLAGS.input_table,
      batch_sizes=[int(size) for size in FLAGS.batch_size],
      num_instances_per_record=2,
      shuffle=True,
      num_epochs=None,
      keypoint_names_3d=profile_3d.keypoint_names,
      keypoint_names_2d=profile_2d.keypoint_names,
      min_keypoint_score_2d=FLAGS.min_input_keypoint_score_2d,
      shuffle_buffer_size=FLAGS.input_shuffle_buffer_size,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  # Preprocess the 3D keypoints; the result replaces the raw keypoints.
  processed_keypoints_3d, side_outputs_3d = keypoint_preprocessor_3d(
      inputs[common_module.KEY_KEYPOINTS_3D],
      keypoint_profile_3d=profile_3d,
      normalize_keypoints_3d=True)
  inputs[common_module.KEY_KEYPOINTS_3D] = processed_keypoints_3d
  inputs.update(side_outputs_3d)

  # KEY_PREPROCESSED_KEYPOINTS_3D is not written directly in this function;
  # presumably it arrives with the preprocessor side outputs merged above --
  # confirm against `keypoint_preprocessor_3d`.
  create_model_input_fn = configs['create_model_input_fn']
  model_inputs, side_inputs = create_model_input_fn(
      inputs[common_module.KEY_KEYPOINTS_2D],
      inputs[common_module.KEY_KEYPOINT_MASKS_2D],
      inputs[common_module.KEY_PREPROCESSED_KEYPOINTS_3D],
      model_input_keypoint_type=FLAGS.model_input_keypoint_type,
      normalize_keypoints_2d=True,
      keypoint_profile_2d=profile_2d,
      keypoint_profile_3d=profile_3d,
      azimuth_range=configs['random_projection_azimuth_range'],
      elevation_range=configs['random_projection_elevation_range'],
      roll_range=configs['random_projection_roll_range'],
      normalized_camera_depth_range=(
          configs['random_projection_camera_depth_range']))
  inputs['model_inputs'] = model_inputs

  data_utils.merge_dict(side_inputs, inputs)
  return inputs