def _create_test_tfrecord(self, test_tfrecord_file, num_samples,
                          input_image_size):
  """Writes `num_samples` copies of one classification example to a TFRecord.

  Args:
    test_tfrecord_file: Destination path of the TFRecord file to write.
    num_samples: How many copies of the (single) generated example to write.
    input_image_size: Sequence whose first two entries are used as the image
      height and width of the generated example.
  """
  height, width = input_image_size[0], input_image_size[1]
  serialized = tfexample_utils.create_classification_example(
      image_height=height, image_width=width)
  # The helper's return value is parsed back into a tf.train.Example proto.
  example = tf.train.Example.FromString(serialized)
  tfexample_utils.dump_to_tfrecord(
      record_file=test_tfrecord_file,
      tf_examples=num_samples * [example])
def setUp(self):
  """Creates a model directory and a TFRecord file holding 8 YT8M examples.

  Sets `self._model_dir` and `self._data_path`, both under the test's
  temporary directory.
  """
  super().setUp()
  self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
  tf.io.gfile.makedirs(self._model_dir)

  data_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(data_dir)
  self._data_path = os.path.join(data_dir, 'data.tfrecord')

  num_examples = 8
  examples = [utils.make_yt8m_example() for _ in range(num_examples)]
  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def setUp(self):
  """Writes two video+audio test examples to a temporary TFRecord file.

  Sets `self._data_path` to the written file, located under the test's
  temporary directory.
  """
  # Zero-argument super() is equivalent here and matches the other setUp
  # methods; it also survives a rename of the test class.
  super().setUp()
  data_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(data_dir)
  self._data_path = os.path.join(data_dir, 'data.tfrecord')

  examples = []
  for _ in range(2):
    # NOTE(review): labels come from the unseeded global `random` module, so
    # they differ between runs. randint's upper bound is inclusive, so labels
    # lie in [0, 100].
    label = random.randint(0, 100)
    examples.append(
        tfexample_utils.make_video_test_example(
            image_shape=(36, 36, 3), audio_shape=(20, 128), label=label))
  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def setUp(self):
  """Writes twenty identical-spec 3D image test examples to a TFRecord file.

  Each example is created with height, width and volume of 32 and 2 image
  channels. Sets `self._data_path` to the written file.
  """
  super().setUp()
  data_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(data_dir)
  self._data_path = os.path.join(data_dir, 'data.tfrecord')

  example_kwargs = dict(
      image_height=32, image_width=32, image_volume=32, image_channel=2)
  examples = [
      tfexample_utils.create_3d_image_test_example(**example_kwargs)
      for _ in range(20)
  ]
  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_read_video_level_float_input(self, include_video_id):
  """Reads float context and frame features through the YT8M input reader.

  Writes 8 examples built by `utils.MakeExampleWithFloatFeatures` to a new
  TFRecord file and reads one batch back. The test checks that:
  - the batch has the expected keys;
  - for the first frame of the first example, `video_matrix` holds the
    context feature in its first `feature_sizes[0]` columns and the frame
    feature in the remaining columns;
  - the nonzero label indices match the stored labels;
  - every output tensor has the expected shape.

  Args:
    include_video_id: Value forwarded to `params.include_video_id`; when set,
      a `video_ids` tensor is also expected in the batch.
  """
  data_dir = os.path.join(self.get_temp_dir(), 'data2')
  tf.io.gfile.makedirs(data_dir)
  data_path = os.path.join(data_dir, 'data2.tfrecord')
  examples = [
      utils.MakeExampleWithFloatFeatures(self.num_segment) for _ in range(8)
  ]
  tfexample_utils.dump_to_tfrecord(data_path, tf_examples=examples)

  context_name = 'VIDEO_EMBEDDING/context_feature/floats'
  frame_name = 'FEATURE/feature/floats'

  params = yt8m_configs.yt8m(is_training=False)
  params.global_batch_size = 4
  params.segment_labels = False
  params.input_path = data_path
  params.num_frames = 2
  params.max_frames = 2
  params.feature_names = (context_name, frame_name)
  params.feature_sources = ('context', 'feature')
  params.feature_dtypes = ('float32', 'float32')
  params.feature_sizes = (256, 2048)
  params.feature_from_bytes = (False, False)
  params.include_video_id = include_video_id

  example = next(iter(self.create_input_reader(params).read()))
  for key, value in example.items():
    logging.info('DEBUG read example %r %r %r', key, value.shape, type(value))
  logging.info('DEBUG read example %r', example['video_matrix'][0, 0, :])

  expected_keys = ['video_matrix', 'labels', 'num_frames']
  if include_video_id:
    expected_keys.append('video_ids')
  self.assertCountEqual(expected_keys, example.keys())

  # Check tensor values against the first written example.
  source = examples[0]
  context_size = params.feature_sizes[0]
  expected_context = source.context.feature[context_name].float_list.value
  expected_feature = source.feature_lists.feature_list[
      frame_name].feature[0].float_list.value
  expected_labels = source.context.feature[
      params.label_field].int64_list.value
  self.assertAllEqual(expected_feature,
                      example['video_matrix'][0, 0, context_size:])
  self.assertAllEqual(expected_context,
                      example['video_matrix'][0, 0, :context_size])
  self.assertAllEqual(
      np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)

  # Check tensor shapes.
  batch_size = params.global_batch_size
  self.assertEqual(
      example['video_matrix'].shape.as_list(),
      [batch_size, params.max_frames, sum(params.feature_sizes)])
  self.assertEqual(example['labels'].shape.as_list(),
                   [batch_size, params.num_classes])
  self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
  if include_video_id:
    self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
  """Writes `example` to `tfrecord_file`, repeated `num_samples` times.

  Args:
    tfrecord_file: Destination path of the TFRecord file to write.
    example: The example object to write repeatedly.
    num_samples: Number of times `example` is written.
  """
  replicated = num_samples * [example]
  tfexample_utils.dump_to_tfrecord(
      record_file=tfrecord_file, tf_examples=replicated)