def testReadSequenceBatchFromTable(self):
  """Reads the first batch of 5-step sequences from one TFRecord table.

  Parses `tfe-1-seq.tfrecords` with LegacyCoco13 2D keypoint names, 2D
  keypoint scores, and 6-way class targets (label confidence used as the
  target), then checks that the first element yielded by the dataset has
  exactly the expected keys, each with the expected static shape.
  """
  # Assume $PWD == "google_research/".
  data_dir = 'poem/testdata'
  record_path = os.path.join(FLAGS.test_srcdir, data_dir,
                             'tfe-1-seq.tfrecords')
  coco13_names = (
      keypoint_profiles.LegacyCoco13KeypointProfile2D().keypoint_names)
  parser = tfe_input_layer.create_tfe_parser(
      keypoint_names_2d=coco13_names,
      include_keypoint_scores_2d=True,
      num_classes=6,
      use_label_confidence_as_class_target=True,
      num_objects=1,
      sequence_length=5)
  dataset = tfe_input_layer.read_from_table(
      [record_path], dataset_class=tf.data.TFRecordDataset, parser_fn=parser)
  batch = next(iter(dataset))

  # Ordered table of output key -> expected shape; dict order fixes the
  # order in which the shape assertions run.
  expected_shapes = {
      'image_sizes': [1, 5, 2],
      'keypoints_2d': [1, 5, 13, 2],
      'keypoint_scores_2d': [1, 5, 13],
      'class_targets': [1, 6],
      'class_weights': [1, 6],
  }
  self.assertCountEqual(batch.keys(), list(expected_shapes))
  for key, shape in expected_shapes.items():
    self.assertEqual(batch[key].shape, shape)
def testReadBatchFromThreeTables(self):
  """Reads one combined batch drawn from three tables.

  The same `tfe-2.tfrecords` table is passed once per entry in
  `batch_sizes`, with shuffling and `drop_remainder=True`. The test checks
  that the first yielded batch has the expected keys and that its leading
  dimension equals the sum of the per-table batch sizes.
  """
  testdata_dir = 'poem/testdata'  # Assume $PWD == "google_research/".
  table_path = os.path.join(FLAGS.test_srcdir, testdata_dir,
                            'tfe-2.tfrecords')
  batch_sizes = [4, 2, 3]
  # Computed before the call so the expectation does not depend on what the
  # reader does with the list it receives.
  total_batch_size = sum(batch_sizes)  # 4 + 2 + 3 = 9.
  parser_fn = tfe_input_layer.create_tfe_parser(
      keypoint_names_2d=(
          keypoint_profiles.LegacyCoco13KeypointProfile2D().keypoint_names),
      keypoint_names_3d=(
          keypoint_profiles.LegacyH36m17KeypointProfile3D().keypoint_names),
      include_keypoint_scores_2d=True,
      include_keypoint_scores_3d=False,
      num_objects=2)
  inputs = tfe_input_layer.read_batch_from_tables(
      [table_path] * len(batch_sizes),  # One table entry per batch size.
      batch_sizes=batch_sizes,
      drop_remainder=True,
      shuffle=True,
      parser_fn=parser_fn)
  inputs = next(iter(inputs))

  # 'keypoint_scores_3d' is expected to be absent because the parser was
  # built with include_keypoint_scores_3d=False.
  self.assertCountEqual(inputs.keys(), [
      'image_sizes', 'keypoints_2d', 'keypoint_scores_2d', 'keypoints_3d'
  ])
  # Each shape assertion runs once (the original checked 'image_sizes'
  # twice).
  self.assertEqual(inputs['image_sizes'].shape, [total_batch_size, 2, 2])
  self.assertEqual(inputs['keypoints_2d'].shape,
                   [total_batch_size, 2, 13, 2])
  self.assertEqual(inputs['keypoint_scores_2d'].shape,
                   [total_batch_size, 2, 13])
  self.assertEqual(inputs['keypoints_3d'].shape,
                   [total_batch_size, 2, 17, 3])