Code Example #1
def _create_test_tfrecord(self, test_tfrecord_file, num_samples,
                          input_image_size):
  example = tf.train.Example.FromString(
      tfexample_utils.create_classification_example(
          image_height=input_image_size[0],
          image_width=input_image_size[1]))
  examples = [example] * num_samples
  tfexample_utils.dump_to_tfrecord(record_file=test_tfrecord_file,
                                   tf_examples=examples)
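The helper parses a serialized classification example back into a tf.train.Example, replicates it num_samples times, and writes the copies with dump_to_tfrecord; duplicating one proto is a cheap way to fill a test TFRecord with a known payload.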
Code Example #2
def setUp(self):
  super().setUp()
  self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
  tf.io.gfile.makedirs(self._model_dir)

  data_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(data_dir)
  self._data_path = os.path.join(data_dir, 'data.tfrecord')
  examples = [utils.make_yt8m_example() for _ in range(8)]
  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
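Here setUp both creates a model directory and writes eight YT8M examples to a temporary TFRecord; note that dump_to_tfrecord also accepts the record path as a positional argument.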
Code Example #3
def setUp(self):
  super(VideoClassificationTaskTest, self).setUp()
  data_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(data_dir)
  self._data_path = os.path.join(data_dir, 'data.tfrecord')
  # pylint: disable=g-complex-comprehension
  examples = [
      tfexample_utils.make_video_test_example(
          image_shape=(36, 36, 3),
          audio_shape=(20, 128),
          label=random.randint(0, 100))
      for _ in range(2)
  ]
  # pylint: enable=g-complex-comprehension
  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
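This variant writes two synthetic video examples with 36x36x3 frames, 20x128 audio features, and a random label in [0, 100], so each test run exercises a different label value.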
Code Example #4
def setUp(self):
  super().setUp()
  data_dir = os.path.join(self.get_temp_dir(), 'data')
  tf.io.gfile.makedirs(data_dir)
  self._data_path = os.path.join(data_dir, 'data.tfrecord')
  # pylint: disable=g-complex-comprehension
  examples = [
      tfexample_utils.create_3d_image_test_example(
          image_height=32, image_width=32, image_volume=32, image_channel=2)
      for _ in range(20)
  ]
  # pylint: enable=g-complex-comprehension
  tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
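The same pattern applied to volumetric data: twenty 32x32x32 two-channel 3D image examples are serialized before each test.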
Code Example #5
def test_read_video_level_float_input(self, include_video_id):
  data_dir = os.path.join(self.get_temp_dir(), 'data2')
  tf.io.gfile.makedirs(data_dir)
  data_path = os.path.join(data_dir, 'data2.tfrecord')
  examples = [
      utils.MakeExampleWithFloatFeatures(self.num_segment)
      for _ in range(8)
  ]
  tfexample_utils.dump_to_tfrecord(data_path, tf_examples=examples)

  params = yt8m_configs.yt8m(is_training=False)
  params.global_batch_size = 4
  params.segment_labels = False
  params.input_path = data_path
  params.num_frames = 2
  params.max_frames = 2
  params.feature_names = ('VIDEO_EMBEDDING/context_feature/floats',
                          'FEATURE/feature/floats')
  params.feature_sources = ('context', 'feature')
  params.feature_dtypes = ('float32', 'float32')
  params.feature_sizes = (256, 2048)
  params.feature_from_bytes = (False, False)
  params.include_video_id = include_video_id
  reader = self.create_input_reader(params)

  dataset = reader.read()
  iterator = iter(dataset)
  example = next(iterator)

  for k, v in example.items():
    logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
  logging.info('DEBUG read example %r', example['video_matrix'][0, 0, :])
  if include_video_id:
    self.assertCountEqual(
        ['video_matrix', 'labels', 'num_frames', 'video_ids'],
        example.keys())
  else:
    self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
                          example.keys())

  # Check tensor values.
  expected_context = examples[0].context.feature[
      'VIDEO_EMBEDDING/context_feature/floats'].float_list.value
  expected_feature = examples[0].feature_lists.feature_list[
      'FEATURE/feature/floats'].feature[0].float_list.value
  expected_labels = examples[0].context.feature[
      params.label_field].int64_list.value
  self.assertAllEqual(
      expected_feature,
      example['video_matrix'][0, 0, params.feature_sizes[0]:])
  self.assertAllEqual(
      expected_context,
      example['video_matrix'][0, 0, :params.feature_sizes[0]])
  self.assertAllEqual(
      np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)

  # Check tensor shape.
  batch_size = params.global_batch_size
  self.assertEqual(
      example['video_matrix'].shape.as_list(),
      [batch_size, params.max_frames, sum(params.feature_sizes)])
  self.assertEqual(example['labels'].shape.as_list(),
                   [batch_size, params.num_classes])
  self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
  if include_video_id:
    self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
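The value assertions rely on the reader concatenating the two features along the last axis: with feature_sizes = (256, 2048), columns [0, 256) of video_matrix hold the context embedding and columns [256, 2304) hold the frame feature, which is why slicing at params.feature_sizes[0] recovers each original float list.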
Code Example #6
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
  examples = [example] * num_samples
  tfexample_utils.dump_to_tfrecord(
      record_file=tfrecord_file, tf_examples=examples)
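All six snippets share one pattern: build a list of tf.train.Example protos, then pass it to tfexample_utils.dump_to_tfrecord together with the output path. Below is a minimal self-contained round-trip sketch of that pattern; the import path official.vision.dataloaders.tfexample_utils and the 'label' feature name are assumptions for illustration, not taken from the snippets above.

import os
import tempfile

import tensorflow as tf
# Assumed TensorFlow Model Garden import path; adjust to your checkout.
from official.vision.dataloaders import tfexample_utils

# Build one tf.train.Example with a single int64 'label' feature
# (the feature name is illustrative).
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[3])),
        }))

# Write four copies of the example, as the helpers above do.
record_file = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')
tfexample_utils.dump_to_tfrecord(record_file=record_file,
                                 tf_examples=[example] * 4)

# Read the records back to confirm the round trip.
for serialized in tf.data.TFRecordDataset(record_file):
  parsed = tf.train.Example.FromString(serialized.numpy())
  assert parsed.features.feature['label'].int64_list.value[0] == 3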