コード例 #1
0
 def _build_iterator_graph(self, num_epochs):
     """Build the ops used to iterate, save and restore a record iterator.

     Args:
       num_epochs: Passed to `repeat()` on the FixedLengthRecordDataset built
         over the files returned by `self._createFiles()`.

     Returns:
       A tuple `(init_op, get_next_op, save_op, restore_op)`. The save and
       restore ops both use the path `<temp dir>/iterator`.
     """
     input_files = self._createFiles()
     checkpoint_path = os.path.join(self.get_temp_dir(), "iterator")

     record_dataset = dataset_ops.FixedLengthRecordDataset(
         input_files, self._record_bytes, self._header_bytes,
         self._footer_bytes)
     repeated_dataset = record_dataset.repeat(num_epochs)

     iterator = repeated_dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next_op = iterator.get_next()

     # Save and restore act on the same underlying iterator resource.
     resource = iterator._iterator_resource
     save_op = gen_dataset_ops.save_iterator(resource, checkpoint_path)
     restore_op = gen_dataset_ops.restore_iterator(resource, checkpoint_path)
     return init_op, get_next_op, save_op, restore_op
コード例 #2
0
    def testFixedLengthRecordDatasetBuffering(self):
        """Read all records through a dataset created with `buffer_size=10`.

        Builds a FixedLengthRecordDataset over the files returned by
        `self._createFiles()` and checks that a one-shot iterator yields
        `self._record(j, i)` for every file `j` and record `i`, in that order,
        and then raises `OutOfRangeError`.
        """
        test_filenames = self._createFiles()
        dataset = dataset_ops.FixedLengthRecordDataset(test_filenames,
                                                       self._record_bytes,
                                                       self._header_bytes,
                                                       self._footer_bytes,
                                                       buffer_size=10)
        iterator = dataset.make_one_shot_iterator()
        # Create the get_next op once. Every call to `iterator.get_next()`
        # adds a new op to the graph, so calling it inside the loops below
        # would grow the graph on each iteration.
        get_next = iterator.get_next()

        with self.test_session() as sess:
            for j in range(self._num_files):
                for i in range(self._num_records):
                    self.assertEqual(self._record(j, i), sess.run(get_next))
            # The iterator must be exhausted after the last record.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)
コード例 #3
0
    def testFixedLengthRecordDataset(self):
        """Check FixedLengthRecordDataset output across several feed setups.

        The file list, epoch count and batch size are placeholders, and a
        single iterator built with `Iterator.from_structure` is re-initialized
        for each scenario: file 0 alone, file 1 alone, both files, both files
        for ten epochs, and both files for ten epochs in batches of
        `self._num_records`.
        """
        all_files = self._createFiles()
        filenames_ph = array_ops.placeholder(dtypes.string, shape=[None])
        epochs_ph = array_ops.placeholder(dtypes.int64, shape=[])
        batch_size_ph = array_ops.placeholder(dtypes.int64, shape=[])

        record_dataset = dataset_ops.FixedLengthRecordDataset(
            filenames_ph, self._record_bytes, self._header_bytes,
            self._footer_bytes)
        repeat_dataset = record_dataset.repeat(epochs_ph)
        batch_dataset = repeat_dataset.batch(batch_size_ph)

        iterator = dataset_ops.Iterator.from_structure(
            batch_dataset.output_types)
        init_op = iterator.make_initializer(repeat_dataset)
        init_batch_op = iterator.make_initializer(batch_dataset)
        get_next = iterator.get_next()

        with self.test_session() as sess:

            def expect_file(file_idx):
                # Every record of one file, in order, one element per run.
                for rec_idx in range(self._num_records):
                    expected = self._record(file_idx, rec_idx)
                    self.assertEqual(expected, sess.run(get_next))

            def expect_exhausted():
                # The next fetch must report the end of the dataset.
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

            # Single epoch over file 0 only.
            sess.run(init_op,
                     feed_dict={filenames_ph: [all_files[0]], epochs_ph: 1})
            expect_file(0)
            expect_exhausted()

            # Single epoch over file 1 only.
            sess.run(init_op,
                     feed_dict={filenames_ph: [all_files[1]], epochs_ph: 1})
            expect_file(1)
            expect_exhausted()

            # Single epoch over both files.
            sess.run(init_op,
                     feed_dict={filenames_ph: all_files, epochs_ph: 1})
            for file_idx in range(self._num_files):
                expect_file(file_idx)
            expect_exhausted()

            # Ten epochs over both files.
            sess.run(init_op,
                     feed_dict={filenames_ph: all_files, epochs_ph: 10})
            for _ in range(10):
                for file_idx in range(self._num_files):
                    expect_file(file_idx)
            expect_exhausted()

            # Ten epochs over both files, batched so that each fetched batch
            # holds all the records of one file.
            sess.run(init_batch_op,
                     feed_dict={
                         filenames_ph: all_files,
                         epochs_ph: 10,
                         batch_size_ph: self._num_records,
                     })
            for _ in range(10):
                for file_idx in range(self._num_files):
                    expected_batch = [
                        self._record(file_idx, rec_idx)
                        for rec_idx in range(self._num_records)
                    ]
                    self.assertAllEqual(expected_batch, sess.run(get_next))
            expect_exhausted()