def _build_iterator_graph(self, num_epochs):
  # Builds init/get_next ops for a FixedLengthRecordDataset repeated for
  # `num_epochs`, plus save/restore ops that checkpoint the iterator state
  # to a file in the test's temporary directory.
  filenames = self._createFiles()
  path = os.path.join(self.get_temp_dir(), "iterator")
  dataset = (dataset_ops.FixedLengthRecordDataset(
      filenames, self._record_bytes, self._header_bytes,
      self._footer_bytes).repeat(num_epochs))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next_op = iterator.get_next()
  # Note: the save/restore ops reach into the private `_iterator_resource`
  # handle to serialize and restore the iterator's position.
  save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
  restore_op = gen_dataset_ops.restore_iterator(
      iterator._iterator_resource, path)
  return init_op, get_next_op, save_op, restore_op
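# A sketch of how the helper above might be exercised: consume part of the
# input, save the iterator, then restore it in a fresh graph and check that
# iteration resumes at the saved position. The test name, the break point,
# and the `ops` import (tensorflow.python.framework.ops) are assumptions for
# illustration, not taken from the original file.
def testSaveRestoreSketch(self):
  num_epochs = 5
  break_point = self._num_records // 2

  with ops.Graph().as_default() as g:
    init_op, get_next_op, save_op, _ = self._build_iterator_graph(num_epochs)
    with self.test_session(graph=g) as sess:
      sess.run(init_op)
      # Consume the first file's records up to the break point, then save.
      for i in range(break_point):
        self.assertEqual(self._record(0, i), sess.run(get_next_op))
      sess.run(save_op)

  with ops.Graph().as_default() as g:
    init_op, get_next_op, _, restore_op = self._build_iterator_graph(
        num_epochs)
    with self.test_session(graph=g) as sess:
      sess.run(init_op)
      sess.run(restore_op)
      # Iteration should pick up where the saved iterator left off.
      self.assertEqual(self._record(0, break_point), sess.run(get_next_op))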
def testFixedLengthRecordDatasetBuffering(self):
  test_filenames = self._createFiles()
  dataset = dataset_ops.FixedLengthRecordDataset(
      test_filenames,
      self._record_bytes,
      self._header_bytes,
      self._footer_bytes,
      buffer_size=10)
  iterator = dataset.make_one_shot_iterator()
  # Create the get_next op once, outside the loops, so repeated sess.run
  # calls do not keep adding new ops to the graph.
  get_next = iterator.get_next()

  with self.test_session() as sess:
    for j in range(self._num_files):
      for i in range(self._num_records):
        self.assertEqual(self._record(j, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
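# For reference, these tests rely on fixture helpers that are not shown in
# this section. A minimal sketch of what they are assumed to look like; the
# record payload, file naming, and the `compat` import
# (tensorflow.python.util.compat) are illustrative assumptions. The fixture
# attributes (self._num_files, self._num_records, self._record_bytes,
# self._header_bytes, self._footer_bytes) are assumed to be set in setUp().
def _record(self, f, r):
  # A deterministic payload of exactly self._record_bytes bytes for
  # record `r` of file `f`.
  return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

def _createFiles(self):
  # Writes self._num_files files, each laid out as
  # [header][record_0]...[record_{n-1}][footer], and returns their paths.
  filenames = []
  for i in range(self._num_files):
    fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
    filenames.append(fn)
    with open(fn, "wb") as f:
      f.write(b"H" * self._header_bytes)
      for j in range(self._num_records):
        f.write(self._record(i, j))
      f.write(b"F" * self._footer_bytes)
  return filenames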
def testFixedLengthRecordDataset(self):
  test_filenames = self._createFiles()
  filenames = array_ops.placeholder(dtypes.string, shape=[None])
  num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
  batch_size = array_ops.placeholder(dtypes.int64, shape=[])

  repeat_dataset = (dataset_ops.FixedLengthRecordDataset(
      filenames, self._record_bytes, self._header_bytes,
      self._footer_bytes).repeat(num_epochs))
  batch_dataset = repeat_dataset.batch(batch_size)

  iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
  init_op = iterator.make_initializer(repeat_dataset)
  init_batch_op = iterator.make_initializer(batch_dataset)
  get_next = iterator.get_next()

  with self.test_session() as sess:
    # Basic test: read from file 0.
    sess.run(
        init_op, feed_dict={
            filenames: [test_filenames[0]],
            num_epochs: 1
        })
    for i in range(self._num_records):
      self.assertEqual(self._record(0, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Basic test: read from file 1.
    sess.run(
        init_op, feed_dict={
            filenames: [test_filenames[1]],
            num_epochs: 1
        })
    for i in range(self._num_records):
      self.assertEqual(self._record(1, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Basic test: read from both files.
    sess.run(
        init_op, feed_dict={
            filenames: test_filenames,
            num_epochs: 1
        })
    for j in range(self._num_files):
      for i in range(self._num_records):
        self.assertEqual(self._record(j, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Test repeated iteration through both files.
    sess.run(
        init_op, feed_dict={
            filenames: test_filenames,
            num_epochs: 10
        })
    for _ in range(10):
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertEqual(self._record(j, i), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)

    # Test batched and repeated iteration through both files.
    sess.run(
        init_batch_op,
        feed_dict={
            filenames: test_filenames,
            num_epochs: 10,
            batch_size: self._num_records
        })
    for _ in range(10):
      for j in range(self._num_files):
        self.assertAllEqual(
            [self._record(j, i) for i in range(self._num_records)],
            sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
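# The byte accounting these tests depend on, made explicit. FixedLengthRecordDataset
# skips header_bytes at the start of each file, ignores footer_bytes at the end,
# and yields one element per record_bytes in between, so record i of a file
# starts at offset header_bytes + i * record_bytes. The helper below is a
# hypothetical sketch, not part of the original test.
def _expected_file_size(self):
  # Total size of each fixture file: header + records + footer.
  return (self._header_bytes + self._num_records * self._record_bytes +
          self._footer_bytes)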