Example 1
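
Writes _NUM_FILES small TFRecord files into the temporary test directory, builds a StreamingFilesDataset over the matching file pattern with filetype='tfrecord', and checks that every written record comes back. The stream repeats the input files, so the test fetches four times as many elements as were written and compares sets rather than relying on order.
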
    def testTFRecordDataset(self):
        all_contents = []
        for i in range(_NUM_FILES):
            filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
            writer = python_io.TFRecordWriter(filename)
            for j in range(_NUM_ENTRIES):
                record = compat.as_bytes('Record %d of file %d' % (j, i))
                writer.write(record)
                all_contents.append(record)
            writer.close()

        dataset = datasets.StreamingFilesDataset(
            os.path.join(self.get_temp_dir(), 'tf_record*'),
            filetype='tfrecord')

        with ops.device(self._worker_device):
            iterator = dataset_ops.make_initializable_iterator(dataset)
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = []
        for _ in range(4 * len(all_contents)):
            retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

        self.assertEqual(set(all_contents), set(retrieved_values))
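
All six snippets are method bodies from a single test class and reference imports, constants, and fixture attributes defined elsewhere in that file. A rough sketch of those assumed surrounding definitions follows; the module paths are the TF 1.x contrib-era ones and the constant values are illustrative, not taken from the original.

# Assumed context for the snippets in these examples (not part of the original
# excerpts; module paths follow the TF 1.x contrib layout).
import os

from tensorflow.contrib.tpu.python.tpu import datasets
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.util import compat

_NUM_FILES = 10    # assumed value; the tests only need a small positive count
_NUM_ENTRIES = 20  # assumed value

# The enclosing test fixture is also assumed to provide:
#   self._worker_device -- device string of the worker the reader ops are placed on
#   self._sess          -- a tf.Session connected to that worker
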
Example 2
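
The same pattern with plain text files: each file holds newline-delimited lines, and filetype='text' is expected to yield every line of every matching file as a separate element.
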
    def testTextLineDataset(self):
        all_contents = []
        for i in range(_NUM_FILES):
            filename = os.path.join(self.get_temp_dir(),
                                    'text_line.%d.txt' % i)
            contents = []
            for j in range(_NUM_ENTRIES):
                contents.append(compat.as_bytes('%d: %d' % (i, j)))
            with open(filename, 'wb') as f:
                f.write(b'\n'.join(contents))
            all_contents.extend(contents)

        dataset = datasets.StreamingFilesDataset(
            os.path.join(self.get_temp_dir(), 'text_line.*.txt'),
            filetype='text')

        with ops.device(self._worker_device):
            iterator = dataset_ops.make_initializable_iterator(dataset)
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = []
        for _ in range(4 * len(all_contents)):
            retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

        self.assertEqual(set(all_contents), set(retrieved_values))
Example 3
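
Rather than a built-in filetype string, this test passes a callable that maps a filename to a Dataset (here a FixedLengthRecordDataset over fixed-width records), showing that an arbitrary per-file reader can be plugged in.
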
    def testArbitraryReaderFunc(self):
        def MakeRecord(i, j):
            return compat.as_bytes('%04d-%04d' % (i, j))

        record_bytes = len(MakeRecord(10, 200))

        all_contents = []
        for i in range(_NUM_FILES):
            filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
            with open(filename, 'wb') as f:
                for j in range(_NUM_ENTRIES):
                    record = MakeRecord(i, j)
                    f.write(record)
                    all_contents.append(record)

        def FixedLengthFile(filename):
            return readers.FixedLengthRecordDataset(filename, record_bytes)

        dataset = datasets.StreamingFilesDataset(
            os.path.join(self.get_temp_dir(), 'fixed_length*'),
            filetype=FixedLengthFile)

        with ops.device(self._worker_device):
            iterator = dataset_ops.make_initializable_iterator(dataset)
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = []
        for _ in range(4 * len(all_contents)):
            retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

        self.assertEqual(set(all_contents), set(retrieved_values))
Example 4
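
The first argument does not have to be a file pattern: here it is itself a Dataset (Dataset.range(10)), and each of its elements is handed to the filetype callable (as the unused dummy argument), which builds a fresh Dataset.from_generator source. The test fetches one element and checks it is a 2-tuple holding a scalar 1 and a length-10 vector of ones.
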
    def testArbitraryReaderFuncFromDatasetGenerator(self):
        def my_generator():
            yield (1, [1] * 10)

        def gen_dataset(dummy):
            return dataset_ops.Dataset.from_generator(
                my_generator, (dtypes.int64, dtypes.int64),
                (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))

        dataset = datasets.StreamingFilesDataset(dataset_ops.Dataset.range(10),
                                                 filetype=gen_dataset)

        with ops.device(self._worker_device):
            iterator = dataset_ops.make_initializable_iterator(dataset)
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = self._sess.run(get_next)

        self.assertIsInstance(retrieved_values, (list, tuple))
        self.assertEqual(len(retrieved_values), 2)
        self.assertEqual(retrieved_values[0], 1)
        self.assertItemsEqual(retrieved_values[1], [1] * 10)
Example 5
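
Passing a value that is neither a file pattern string nor a Dataset (here the integer 123) is expected to raise ValueError.
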
    def testUnexpectedFilesType(self):
        with self.assertRaises(ValueError):
            datasets.StreamingFilesDataset(123, filetype='tfrecord')
Example 6
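
Similarly, filetype must be one of the supported strings or a callable; an unsupported value (here the integer 3) is expected to raise ValueError.
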
    def testUnexpectedFiletypeType(self):
        with self.assertRaises(ValueError):
            datasets.StreamingFilesDataset(
                os.path.join(self.get_temp_dir(), '*'), filetype=3)