  def test_read_video_level_input(self, include_video_id):
    params = yt8m_configs.yt8m(is_training=False)
    params.global_batch_size = 4
    params.segment_labels = False
    params.input_path = self.data_path
    params.include_video_id = include_video_id
    reader = self.create_input_reader(params)

    dataset = reader.read()
    iterator = iter(dataset)
    example = next(iterator)
    for k, v in example.items():
      logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))

    if include_video_id:
      self.assertCountEqual(
          ['video_matrix', 'labels', 'num_frames', 'video_ids'],
          example.keys())
    else:
      self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
                            example.keys())

    batch_size = params.global_batch_size
    self.assertEqual(
        example['video_matrix'].shape.as_list(),
        [batch_size, params.max_frames, sum(params.feature_sizes)])
    self.assertEqual(example['labels'].shape.as_list(),
                     [batch_size, params.num_classes])
    self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
    if include_video_id:
      self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
  def test_read_video_level_float_input(self, include_video_id):
    data_dir = os.path.join(self.get_temp_dir(), 'data2')
    tf.io.gfile.makedirs(data_dir)
    data_path = os.path.join(data_dir, 'data2.tfrecord')
    examples = [
        utils.MakeExampleWithFloatFeatures(self.num_segment) for _ in range(8)
    ]
    tfexample_utils.dump_to_tfrecord(data_path, tf_examples=examples)

    params = yt8m_configs.yt8m(is_training=False)
    params.global_batch_size = 4
    params.segment_labels = False
    params.input_path = data_path
    params.num_frames = 2
    params.max_frames = 2
    params.feature_names = ('VIDEO_EMBEDDING/context_feature/floats',
                            'FEATURE/feature/floats')
    params.feature_sources = ('context', 'feature')
    params.feature_dtypes = ('float32', 'float32')
    params.feature_sizes = (256, 2048)
    params.feature_from_bytes = (False, False)
    params.include_video_id = include_video_id
    reader = self.create_input_reader(params)

    dataset = reader.read()
    iterator = iter(dataset)
    example = next(iterator)
    for k, v in example.items():
      logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
    logging.info('DEBUG read example %r', example['video_matrix'][0, 0, :])

    if include_video_id:
      self.assertCountEqual(
          ['video_matrix', 'labels', 'num_frames', 'video_ids'],
          example.keys())
    else:
      self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
                            example.keys())

    # Check tensor values.
    expected_context = examples[0].context.feature[
        'VIDEO_EMBEDDING/context_feature/floats'].float_list.value
    expected_feature = examples[0].feature_lists.feature_list[
        'FEATURE/feature/floats'].feature[0].float_list.value
    expected_labels = examples[0].context.feature[
        params.label_field].int64_list.value
    self.assertAllEqual(
        expected_feature,
        example['video_matrix'][0, 0, params.feature_sizes[0]:])
    self.assertAllEqual(
        expected_context,
        example['video_matrix'][0, 0, :params.feature_sizes[0]])
    self.assertAllEqual(
        np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)

    # Check tensor shape.
    batch_size = params.global_batch_size
    self.assertEqual(
        example['video_matrix'].shape.as_list(),
        [batch_size, params.max_frames, sum(params.feature_sizes)])
    self.assertEqual(example['labels'].shape.as_list(),
                     [batch_size, params.num_classes])
    self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
    if include_video_id:
      self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])