コード例 #1
0
    def testReadSequenceBatchFromTable(self):
        """Reads one batch of 5-step sequences from a single TFRecord table.

        The parser is configured for one object per example, 2D keypoints
        from the legacy COCO-13 profile with their scores, and 6 classes
        using label confidence as the class target. The test checks that the
        first batch exposes exactly the expected keys and that each tensor
        has the expected shape.
        """
        # Assume $PWD == "google_research/".
        sequence_table = os.path.join(FLAGS.test_srcdir, 'poem/testdata',
                                      'tfe-1-seq.tfrecords')
        coco13_names = (
            keypoint_profiles.LegacyCoco13KeypointProfile2D().keypoint_names)
        parse = tfe_input_layer.create_tfe_parser(
            keypoint_names_2d=coco13_names,
            include_keypoint_scores_2d=True,
            num_classes=6,
            use_label_confidence_as_class_target=True,
            num_objects=1,
            sequence_length=5)
        dataset = tfe_input_layer.read_from_table(
            [sequence_table],
            dataset_class=tf.data.TFRecordDataset,
            parser_fn=parse)
        batch = next(iter(dataset))

        # Expected shape of every output key.
        # The leading dimension is the batch (1) and the second is the
        # sequence length (5) where applicable.
        expected_shapes = {
            'image_sizes': [1, 5, 2],
            'keypoints_2d': [1, 5, 13, 2],
            'keypoint_scores_2d': [1, 5, 13],
            'class_targets': [1, 6],
            'class_weights': [1, 6],
        }
        self.assertCountEqual(batch.keys(), list(expected_shapes))
        for key, shape in expected_shapes.items():
            self.assertEqual(batch[key].shape, shape)
コード例 #2
0
    def testReadBatchFromThreeTables(self):
        """Reads one combined batch drawn from three tables.

        The same TFRecord file is passed once per entry of `batch_sizes`.
        The test expects the combined leading (batch) dimension to equal the
        sum of the per-table batch sizes. The parser is configured for two
        objects per example, with 2D keypoints (legacy COCO-13) including
        scores and 3D keypoints (legacy H36M-17) without scores.
        """
        # Assume $PWD == "google_research/".
        table_path = os.path.join(FLAGS.test_srcdir, 'poem/testdata',
                                  'tfe-2.tfrecords')
        parser_fn = tfe_input_layer.create_tfe_parser(
            keypoint_names_2d=(keypoint_profiles.LegacyCoco13KeypointProfile2D(
            ).keypoint_names),
            keypoint_names_3d=(keypoint_profiles.LegacyH36m17KeypointProfile3D(
            ).keypoint_names),
            include_keypoint_scores_2d=True,
            include_keypoint_scores_3d=False,
            num_objects=2)
        batch_sizes = [4, 2, 3]
        # Derive the expected batch dimension from the inputs instead of
        # hard-coding 9, so the two cannot drift apart.
        total_batch_size = sum(batch_sizes)
        inputs = tfe_input_layer.read_batch_from_tables(
            [table_path] * len(batch_sizes),
            batch_sizes=batch_sizes,
            drop_remainder=True,
            shuffle=True,
            parser_fn=parser_fn)
        inputs = next(iter(inputs))

        # One mapping drives both the key-set check and the shape checks, so
        # every key is asserted exactly once and the two lists cannot diverge.
        expected_shapes = {
            'image_sizes': [total_batch_size, 2, 2],
            'keypoints_2d': [total_batch_size, 2, 13, 2],
            'keypoint_scores_2d': [total_batch_size, 2, 13],
            'keypoints_3d': [total_batch_size, 2, 17, 3],
        }
        self.assertCountEqual(inputs.keys(), expected_shapes.keys())
        for key, expected_shape in expected_shapes.items():
            # `msg=key` identifies the failing tensor in the assertion output.
            self.assertEqual(inputs[key].shape, expected_shape, msg=key)