Example #1
    def test_read_video_level_input(self, include_video_id):
        params = yt8m_configs.yt8m(is_training=False)
        params.global_batch_size = 4
        params.segment_labels = False
        params.input_path = self.data_path
        params.include_video_id = include_video_id
        reader = self.create_input_reader(params)

        dataset = reader.read()
        iterator = iter(dataset)
        example = next(iterator)

        for k, v in example.items():
            logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
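        # Check returned keys.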
        if include_video_id:
            self.assertCountEqual(
                ['video_matrix', 'labels', 'num_frames', 'video_ids'],
                example.keys())
        else:
            self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
                                  example.keys())
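        # Check tensor shape.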
        batch_size = params.global_batch_size
        self.assertEqual(
            example['video_matrix'].shape.as_list(),
            [batch_size, params.max_frames,
             sum(params.feature_sizes)])
        self.assertEqual(example['labels'].shape.as_list(),
                         [batch_size, params.num_classes])
        self.assertEqual(example['num_frames'].shape.as_list(),
                         [batch_size, 1])
        if include_video_id:
            self.assertEqual(example['video_ids'].shape.as_list(),
                             [batch_size, 1])
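
Note: both examples are test methods excerpted from a test class, so self.get_temp_dir, self.data_path, self.num_segment, and self.create_input_reader come from the surrounding harness, and logging, os, np, tf, utils, tfexample_utils, and yt8m_configs are module-level imports that are not shown. The tfexample_utils.dump_to_tfrecord call in Example #2 presumably does little more than the following sketch (an assumption about its behavior, not the helper's actual source):

    import tensorflow as tf

    def dump_to_tfrecord(record_path, tf_examples):
        # Serialize each example proto into a single TFRecord file.
        with tf.io.TFRecordWriter(record_path) as writer:
            for tf_example in tf_examples:
                writer.write(tf_example.SerializeToString())
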
Example #2
    def test_read_video_level_float_input(self, include_video_id):
        data_dir = os.path.join(self.get_temp_dir(), 'data2')
        tf.io.gfile.makedirs(data_dir)
        data_path = os.path.join(data_dir, 'data2.tfrecord')
        examples = [
            utils.MakeExampleWithFloatFeatures(self.num_segment)
            for _ in range(8)
        ]
        tfexample_utils.dump_to_tfrecord(data_path, tf_examples=examples)

        params = yt8m_configs.yt8m(is_training=False)
        params.global_batch_size = 4
        params.segment_labels = False
        params.input_path = data_path
        params.num_frames = 2
        params.max_frames = 2
        params.feature_names = ('VIDEO_EMBEDDING/context_feature/floats',
                                'FEATURE/feature/floats')
        params.feature_sources = ('context', 'feature')
        params.feature_dtypes = ('float32', 'float32')
        params.feature_sizes = (256, 2048)
        params.feature_from_bytes = (False, False)
        params.include_video_id = include_video_id
        reader = self.create_input_reader(params)

        dataset = reader.read()
        iterator = iter(dataset)
        example = next(iterator)

        for k, v in example.items():
            logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
        logging.info('DEBUG read example %r', example['video_matrix'][0, 0, :])
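        # Check returned keys.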
        if include_video_id:
            self.assertCountEqual(
                ['video_matrix', 'labels', 'num_frames', 'video_ids'],
                example.keys())
        else:
            self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
                                  example.keys())

        # Check tensor values.
        expected_context = examples[0].context.feature[
            'VIDEO_EMBEDDING/context_feature/floats'].float_list.value
        expected_feature = examples[0].feature_lists.feature_list[
            'FEATURE/feature/floats'].feature[0].float_list.value
        expected_labels = examples[0].context.feature[
            params.label_field].int64_list.value
        self.assertAllEqual(
            expected_feature,
            example['video_matrix'][0, 0, params.feature_sizes[0]:])
        self.assertAllEqual(
            expected_context,
            example['video_matrix'][0, 0, :params.feature_sizes[0]])
        self.assertAllEqual(
            np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)

        # Check tensor shape.
        batch_size = params.global_batch_size
        self.assertEqual(
            example['video_matrix'].shape.as_list(),
            [batch_size, params.max_frames,
             sum(params.feature_sizes)])
        self.assertEqual(example['labels'].shape.as_list(),
                         [batch_size, params.num_classes])
        self.assertEqual(example['num_frames'].shape.as_list(),
                         [batch_size, 1])
        if include_video_id:
            self.assertEqual(example['video_ids'].shape.as_list(),
                             [batch_size, 1])
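
The fixture helper utils.MakeExampleWithFloatFeatures is not shown; judging from the assertions above, it builds a tf.train.SequenceExample with a 256-dim float context feature, a 2048-dim per-frame float feature list, and integer labels under params.label_field. A minimal sketch along those lines (the function name, signature, label key, and random values are assumptions, not the helper's actual code):

    import numpy as np
    import tensorflow as tf

    def make_example_with_float_features(num_frames,
                                         context_size=256,
                                         frame_size=2048):
        example = tf.train.SequenceExample()
        # Context: one fixed-size float embedding plus integer labels
        # (the 'labels' key is assumed to match params.label_field).
        example.context.feature[
            'VIDEO_EMBEDDING/context_feature/floats'
        ].float_list.value.extend(np.random.rand(context_size))
        example.context.feature['labels'].int64_list.value.extend([1, 3])
        # Feature list: one float vector per frame.
        frame_list = example.feature_lists.feature_list[
            'FEATURE/feature/floats']
        for _ in range(num_frames):
            frame_list.feature.add().float_list.value.extend(
                np.random.rand(frame_size))
        return example

With records of that shape, position [0, 0] of the decoded video_matrix is the 256-dim context embedding concatenated with the first frame's 2048-dim feature, which is exactly what the value assertions above slice apart at params.feature_sizes[0].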