예제 #1
0
    def testOutOfRangeError(self):
        """Over-reading a one-file single_pass_read source raises OutOfRangeError.

        A single TFRecord file is created with
        ``test_utils.create_tfrecord_files`` and read through
        ``parallel_reader.single_pass_read``. Issuing 11 ``sess.run`` reads is
        expected to raise ``errors_impl.OutOfRangeError``.
        """
        with self.cached_session():
            (tfrecord_path,) = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        local_init = variables.local_variables_initializer()

        # NOTE(review): 11 presumably exceeds the number of records written by
        # create_tfrecord_files (testTFRecordReader reads 9 successfully) --
        # confirm against test_utils.
        read_attempts = 11

        with self.cached_session() as sess:
            sess.run(local_init)
            # assertRaises is entered last, so it exits (and absorbs the
            # expected error) before the queue runners are shut down.
            with queues.QueueRunners(sess), \
                    self.assertRaises(errors_impl.OutOfRangeError):
                for _ in range(read_attempts):
                    sess.run([key, value])
예제 #2
0
    def testTFRecordSeparateGetDataset(self):
        """Fetch 'image' and 'label' through separate ``provider.get`` calls.

        The image is resized to ``[height, width]``. The test then checks that
        the evaluated image has shape ``[height, width, 3]`` and the evaluated
        label has shape ``[1]``.
        """
        # mkdtemp's first positional parameter is `suffix`, not a location.
        # Pass `dir` so the dataset is created under the test's own temp dir
        # instead of the system temp dir, where nothing would remove it.
        dataset_dir = tempfile.mkdtemp(suffix='tfrecord_separate_get',
                                       dir=self.get_temp_dir())

        height = 300
        width = 280

        with self.cached_session():
            provider = dataset_data_provider.DatasetDataProvider(
                _create_tfrecord_dataset(dataset_dir))
        [image] = provider.get(['image'])
        [label] = provider.get(['label'])
        image = _resize_image(image, height, width)

        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                # Fetch into new names so the graph tensors are not shadowed.
                image_val, label_val = sess.run([image, label])
            self.assertListEqual([height, width, 3], list(image_val.shape))
            self.assertListEqual([1], list(label_val.shape))
예제 #3
0
    def testTFRecordReader(self):
        """Every key produced by single_pass_read mentions 'flowers'.

        One TFRecord file is created and read 9 times. The string form of
        each returned key must contain 'flowers'.
        """
        with self.cached_session():
            (tfrecord_path,) = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        init_op = variables.local_variables_initializer()

        with self.cached_session() as sess:
            sess.run(init_op)
            with queues.QueueRunners(sess):
                num_reads = 9
                # Only the key is inspected; value is fetched alongside it so
                # each run dequeues a full (key, value) pair.
                fetched_keys = [
                    sess.run([key, value])[0] for _ in range(num_reads)
                ]
                flowers = sum(
                    'flowers' in str(record_key)
                    for record_key in fetched_keys)
                self.assertGreater(flowers, 0)
                self.assertEqual(flowers, num_reads)
예제 #4
0
    def testTFRecordDataset(self):
        """Read record key, image and label from a TFRecord-backed provider.

        The evaluated key must decode to ``<source>:<digits>``, where
        ``<source>`` equals the dataset's first data source. The resized image
        must have shape ``[height, width, 3]`` and the label shape ``[1]``.
        """
        # mkdtemp's first positional parameter is `suffix`, not a location.
        # Pass `dir` so the dataset is created under the test's own temp dir
        # instead of the system temp dir, where nothing would remove it.
        dataset_dir = tempfile.mkdtemp(suffix='tfrecord_dataset',
                                       dir=self.get_temp_dir())

        height = 300
        width = 280

        with self.cached_session():
            test_dataset = _create_tfrecord_dataset(dataset_dir)
            provider = dataset_data_provider.DatasetDataProvider(test_dataset)
            key, image, label = provider.get(['record_key', 'image', 'label'])
            image = _resize_image(image, height, width)

            with session.Session('') as sess:
                with queues.QueueRunners(sess):
                    # Fetch into new names so the graph tensors are not
                    # shadowed by their evaluated values.
                    key_val, image_val, label_val = sess.run(
                        [key, image, label])
            # The fetched key is bytes; expected layout is "<source>:<index>".
            split_key = key_val.decode('utf-8').split(':')
            self.assertEqual(2, len(split_key))
            self.assertEqual(test_dataset.data_sources[0], split_key[0])
            self.assertTrue(split_key[1].isdigit())
            self.assertListEqual([height, width, 3], list(image_val.shape))
            self.assertListEqual([1], list(label_val.shape))