# Imports these snippets rely on (TensorFlow-internal module paths that
# provide the names used below):
import gzip
import os
import zlib

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops


def _build_iterator_graph(self,
                          num_epochs,
                          batch_size=1,
                          compression_type=None,
                          buffer_size=None):
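        # Create plain TFRecord test files; if a compression_type is given,
        # re-encode them below so that read path can be exercised.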
        filenames = self._createFiles()
        if compression_type == "ZLIB":
            zlib_files = []
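            # Write a zlib-compressed copy of each test file.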
            for i, fn in enumerate(filenames):
                with open(fn, "rb") as f:
                    cdata = zlib.compress(f.read())
                    zfn = os.path.join(self.get_temp_dir(),
                                       "tfrecord_%s.z" % i)
                    with open(zfn, "wb") as f:
                        f.write(cdata)
                    zlib_files.append(zfn)
            filenames = zlib_files

        elif compression_type == "GZIP":
            gzip_files = []
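            # Write a GZIP-compressed copy of each test file.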
            for i, fn in enumerate(filenames):
                with open(fn, "rb") as f:
                    gzfn = os.path.join(self.get_temp_dir(),
                                        "tfrecord_%s.gz" % i)
                    with gzip.GzipFile(gzfn, "wb") as gzf:
                        gzf.write(f.read())
                    gzip_files.append(gzfn)
            filenames = gzip_files

        # Read the (possibly compressed) files, repeat for num_epochs,
        # and batch.
        return readers.TFRecordDataset(
            filenames, compression_type,
            buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
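
For reference, here is a minimal sketch of how a test might exercise this helper. It is not part of the original suite; the test name and argument values are illustrative:

  def testGZIPReadTwoEpochs(self):
    # Hypothetical usage of _build_iterator_graph above; drains a
    # GZIP-compressed dataset of two epochs in batches of two.
    dataset = self._build_iterator_graph(
        num_epochs=2, batch_size=2, compression_type="GZIP")
    next_element = dataset.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          sess.run(next_element)  # raises OutOfRangeError when exhausted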
Example #2
  def testReadWithBuffer(self):
    # buffer_size is given in bytes; use a 1 MiB read-ahead buffer.
    one_mebibyte = 2**20
    d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
    iterator = d.make_one_shot_iterator()
    with self.test_session() as sess:
      # Every record of every file should come back, in order.
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next()))
      # A further read must report end-of-sequence.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())
Example #3
    def testReadWithEquivalentDataset(self):
        # TODO(mrry): Add support for tf.SparseTensor as a Dataset component.
        features = {
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
        }
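        # Parse each serialized example into its "file" and "record"
        # features, repeat the data 10 times, and batch elements in pairs.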
        dataset = (readers.TFRecordDataset(self.test_filenames).map(
            lambda x: parsing_ops.parse_single_example(x, features)).repeat(
                10).batch(2))
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(init_op)
            for file_batch, _, _, _, record_batch in self._next_expected_batch(
                    range(self._num_files), 2, 10):
                actual_batch = sess.run(next_element)
                self.assertAllEqual(file_batch, actual_batch["file"])
                self.assertAllEqual(record_batch, actual_batch["record"])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
Example #4
  def setUp(self):
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    # Placeholders let each test feed its own filenames, epoch count,
    # compression type, and batch size into one shared graph.
    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    self.num_epochs = array_ops.placeholder_with_default(
        constant_op.constant(1, dtypes.int64), shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = readers.TFRecordDataset(self.filenames,
                                             self.compression_type).repeat(
                                                 self.num_epochs)
    batch_dataset = repeat_dataset.batch(self.batch_size)

    # One reusable iterator with separate initializers for the repeated
    # (unbatched) dataset and the batched dataset.
    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    self.init_op = iterator.make_initializer(repeat_dataset)
    self.init_batch_op = iterator.make_initializer(batch_dataset)
    self.get_next = iterator.get_next()
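
Finally, a hedged sketch of how a test built on this base class might feed the placeholders defined in setUp; the test name and feed values are illustrative, not from the original file:

  def testReadSingleEpochInFileSizedBatches(self):
    # Hypothetical example using the feedable graph from setUp above.
    with self.test_session() as sess:
      sess.run(self.init_batch_op,
               feed_dict={self.filenames: self.test_filenames,
                          self.batch_size: self._num_records})
      for j in range(self._num_files):
        # Each batch should hold exactly one file's records, in order.
        self.assertAllEqual(
            [self._record(j, i) for i in range(self._num_records)],
            sess.run(self.get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.get_next)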