def _build_iterator_graph(self, num_epochs, batch_size=1,
                          compression_type=None, buffer_size=None):
    """Build a repeated, batched `TFRecordDataset` over freshly created files.

    The files returned by `self._createFiles()` are re-encoded into
    `self.get_temp_dir()` when `compression_type` is "ZLIB" or "GZIP"; for
    any other value the files are used as created.

    Args:
      num_epochs: Forwarded to `repeat()`.
      batch_size: Forwarded to `batch()`.
      compression_type: "ZLIB" or "GZIP" to re-encode the files first; the
        value is also passed through to `readers.TFRecordDataset`.
      buffer_size: Forwarded to `readers.TFRecordDataset` as `buffer_size`.

    Returns:
      The dataset produced by
      `readers.TFRecordDataset(...).repeat(num_epochs).batch(batch_size)`.
    """
    filenames = self._createFiles()

    # Compare by value. `is` against a string literal only succeeds when the
    # interpreter happens to intern both strings, so an equal-but-distinct
    # string object would silently skip compression.
    if compression_type == "ZLIB":
        zlib_files = []
        for i, fn in enumerate(filenames):
            # Finish reading (and close the source) before opening the output.
            with open(fn, "rb") as f:
                cdata = zlib.compress(f.read())
            zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
            with open(zfn, "wb") as zf:
                zf.write(cdata)
            zlib_files.append(zfn)
        filenames = zlib_files
    elif compression_type == "GZIP":
        gzip_files = []
        # Use the files created above, mirroring the ZLIB branch, rather than
        # depending on `self.test_filenames`.
        for i, fn in enumerate(filenames):
            with open(fn, "rb") as f:
                data = f.read()
            gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
            with gzip.GzipFile(gzfn, "wb") as gzf:
                gzf.write(data)
            gzip_files.append(gzfn)
        filenames = gzip_files

    dataset = readers.TFRecordDataset(
        filenames, compression_type, buffer_size=buffer_size)
    return dataset.repeat(num_epochs).batch(batch_size)
def testReadWithBuffer(self):
    """Read all records through a dataset created with a 1 MiB buffer.

    Every record of every test file must come back in file-major order,
    after which the iterator must raise `OutOfRangeError`.
    """
    buffer_bytes = 2**20  # One mebibyte.
    dataset = readers.TFRecordDataset(
        self.test_filenames, buffer_size=buffer_bytes)
    iterator = dataset.make_one_shot_iterator()

    # (file index, record index) pairs in the order they are expected.
    index_pairs = [(file_index, record_index)
                   for file_index in range(self._num_files)
                   for record_index in range(self._num_records)]

    with self.test_session() as sess:
        for file_index, record_index in index_pairs:
            self.assertAllEqual(
                self._record(file_index, record_index),
                sess.run(iterator.get_next()))
        # The data is exhausted, so one more read must fail.
        with self.assertRaises(errors.OutOfRangeError):
            sess.run(iterator.get_next())
def testReadWithEquivalentDataset(self):
    """Parse records inside the dataset pipeline and compare batch by batch.

    Each serialized record is parsed into int64 scalar features "file" and
    "record"; the dataset is repeated 10 times and batched by 2. The results
    are checked against the first and last fields of the tuples yielded by
    `self._next_expected_batch`.
    """
    # TODO(mrry): Add support for tf.SparseTensor as a Dataset component.
    num_epochs = 10
    batch_size = 2
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }

    def _parse(serialized):
        return parsing_ops.parse_single_example(serialized, features)

    parsed = readers.TFRecordDataset(self.test_filenames).map(_parse)
    dataset = parsed.repeat(num_epochs).batch(batch_size)

    iterator = dataset.make_initializable_iterator()
    initializer = iterator.initializer
    next_batch = iterator.get_next()

    with self.test_session() as sess:
        sess.run(initializer)
        expected_batches = self._next_expected_batch(
            range(self._num_files), batch_size, num_epochs)
        for expected_files, _, _, _, expected_records in expected_batches:
            actual = sess.run(next_batch)
            self.assertAllEqual(expected_files, actual["file"])
            self.assertAllEqual(expected_records, actual["record"])
        # All expected batches consumed, so the next read must fail.
        with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_batch)
def setUp(self):
    """Create the test files and the shared placeholder-driven pipeline.

    Sets the file/record counts, stores the paths returned by
    `self._createFiles()`, and builds a `TFRecordDataset` fed by
    placeholders. One iterator is created from the batched dataset's output
    types and given two initializers: `init_op` for the repeated dataset and
    `init_batch_op` for the repeated-and-batched dataset.
    """
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    # Inputs fed at session-run time; epochs and compression have defaults.
    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    default_epochs = constant_op.constant(1, dtypes.int64)
    self.num_epochs = array_ops.placeholder_with_default(
        default_epochs, shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    records = readers.TFRecordDataset(self.filenames, self.compression_type)
    repeated = records.repeat(self.num_epochs)
    batched = repeated.batch(self.batch_size)

    # A single iterator that either dataset can initialize.
    iterator = iterator_ops.Iterator.from_structure(batched.output_types)
    self.init_op = iterator.make_initializer(repeated)
    self.init_batch_op = iterator.make_initializer(batched)
    self.get_next = iterator.get_next()