Example #1
    def testZLibFlushRecord(self):
        fn = self._WriteRecordsToFile([b"small record"], "small_record")
        with open(fn, "rb") as h:
            buff = h.read()

        # Recompress the data one byte at a time with a Z_FULL_FLUSH after
        # each byte; the extra flush blocks and trailing blocks shouldn't
        # break reads.
        compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

        output = b""
        for c in buff:
            if isinstance(c, int):
                # Iterating over a bytes object yields ints on Python 3.
                c = six.int2byte(c)
            output += compressor.compress(c)
            output += compressor.flush(zlib.Z_FULL_FLUSH)

        output += compressor.flush(zlib.Z_FULL_FLUSH)
        output += compressor.flush(zlib.Z_FULL_FLUSH)
        output += compressor.flush(zlib.Z_FINISH)

        # overwrite the original file with the compressed data
        with open(fn, "wb") as h:
            h.write(output)

        with self.test_session() as sess:
            options = tf_record.TFRecordOptions(
                compression_type=TFRecordCompressionType.ZLIB)
            reader = io_ops.TFRecordReader(name="test_reader", options=options)
            queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
            key, value = reader.read(queue)
            queue.enqueue(fn).run()
            queue.close().run()
            k, v = sess.run([key, value])
            self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
            self.assertAllEqual(b"small record", v)
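The _WriteRecordsToFile helper these tests call is not shown in the listing. A minimal sketch of what it plausibly looks like, assuming a test base class with get_temp_dir() and the same tf_record module; the signature is an assumption inferred from the call sites:

    def _WriteRecordsToFile(self, records, name="tf_record", options=None):
        # Sketch: write each record to a TFRecord file in the test's temp
        # directory and return its path.
        fn = os.path.join(self.get_temp_dir(), name)
        writer = tf_record.TFRecordWriter(fn, options)
        for r in records:
            writer.write(r)
        writer.close()
        return fn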
Example #2
    def testReadUpTo(self):
        files = self._CreateFiles()
        with self.cached_session() as sess:
            reader = io_ops.TFRecordReader(name="test_reader")
            queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            batch_size = 3
            key, value = reader.read_up_to(queue, batch_size)

            queue.enqueue_many([files]).run()
            queue.close().run()
            num_k = 0
            num_v = 0

            while True:
                try:
                    k, v = sess.run([key, value])
                    # Test reading *up to* batch_size records
                    self.assertLessEqual(len(k), batch_size)
                    self.assertLessEqual(len(v), batch_size)
                    num_k += len(k)
                    num_v += len(v)
                except errors_impl.OutOfRangeError:
                    break

            # Test that we have read everything
            self.assertEqual(self._num_files * self._num_records, num_k)
            self.assertEqual(self._num_files * self._num_records, num_v)
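Examples #2 through #6 also depend on _CreateFiles, _Record, self._num_files, and self._num_records, none of which appear in the listing. A plausible sketch of those helpers, assuming the _WriteRecordsToFile sketch above and counts set elsewhere on the test class; the exact record wording is an assumption:

    def _Record(self, f, r):
        # Sketch: the byte string the read tests assert against.
        return compat.as_bytes("Record %d of file %d" % (r, f))

    def _CreateFiles(self, options=None):
        # Sketch: write self._num_records records into each of
        # self._num_files TFRecord files and return the file names.
        filenames = []
        for i in range(self._num_files):
            records = [self._Record(i, j) for j in range(self._num_records)]
            filenames.append(
                self._WriteRecordsToFile(records, "tfrecord.%d" % i, options))
        return filenames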
Example #3
    def testReadGzipFiles(self):
        files = self._CreateFiles()
        gzip_files = []
        for i, fn in enumerate(files):
            # Read the raw TFRecord bytes, then gzip them into a new file.
            with open(fn, "rb") as f:
                cdata = f.read()

            zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
            with gzip.GzipFile(zfn, "wb") as gzf:
                gzf.write(cdata)
            gzip_files.append(zfn)

        with self.test_session() as sess:
            options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
            reader = io_ops.TFRecordReader(name="test_reader", options=options)
            queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            key, value = reader.read(queue)

            queue.enqueue_many([gzip_files]).run()
            queue.close().run()
            for i in range(self._num_files):
                for j in range(self._num_records):
                    k, v = sess.run([key, value])
                    self.assertTrue(
                        compat.as_text(k).startswith("%s:" % gzip_files[i]))
                    self.assertAllEqual(self._Record(i, j), v)
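Because this example gzips the raw TFRecord bytes outside TensorFlow, the result can also be verified without a reader op. A small sketch using tf_record.tf_record_iterator, reusing the variable names from the test above:

    options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
    for record in tf_record.tf_record_iterator(gzip_files[0], options):
        print(record)  # e.g. b"Record 0 of file 0"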
Example #4
    def testReadGzipFiles(self):
        # A newer variant of the previous test: the GZIP files are produced
        # directly by the writer via TFRecordOptions, and self.evaluate is
        # used so the test runs under both graph and eager execution.
        options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
        files = self._CreateFiles(options)

        reader = io_ops.TFRecordReader(name="test_reader", options=options)
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        key, value = reader.read(queue)

        self.evaluate(queue.enqueue_many([files]))
        self.evaluate(queue.close())
        for i in range(self._num_files):
            for j in range(self._num_records):
                k, v = self.evaluate([key, value])
                self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
                self.assertAllEqual(self._Record(i, j), v)
Example #5
    def testOneEpoch(self):
        files = self._CreateFiles()
        reader = io_ops.TFRecordReader(name="test_reader")
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        key, value = reader.read(queue)

        self.evaluate(queue.enqueue_many([files]))
        self.evaluate(queue.close())
        for i in range(self._num_files):
            for j in range(self._num_records):
                k, v = self.evaluate([key, value])
                self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
                self.assertAllEqual(self._Record(i, j), v)

        # Once the closed queue is exhausted, another read must fail.
        with self.assertRaisesOpError("is closed and has insufficient elements "
                                      "\\(requested 1, current size 0\\)"):
            k, v = self.evaluate([key, value])
Example #6
    def testOneEpoch(self):
        options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
        files = self._CreateFiles(options)
        with self.test_session() as sess:
            reader = io_ops.TFRecordReader(name="test_reader", options=options)
            queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            key, value = reader.read(queue)

            queue.enqueue_many([files]).run()
            queue.close().run()
            for i in range(self._num_files):
                for j in range(self._num_records):
                    k, v = sess.run([key, value])
                    self.assertTrue(
                        compat.as_text(k).startswith("%s:" % files[i]))
                    self.assertAllEqual(self._Record(i, j), v)

            with self.assertRaisesOpError(
                    "is closed and has insufficient elements "
                    "\\(requested 1, current size 0\\)"):
                k, v = sess.run([key, value])
    def _get_reader(self):
        return io_ops.TFRecordReader()
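The trailing _get_reader factory suggests these tests inherit from a shared base class that is parameterized over reader types. A hedged sketch of how a base-class test might use it; the test name and assertion are assumptions, not code from the source:

    def testGetReaderType(self):
        # Sketch: each subclass's factory should yield its concrete reader.
        reader = self._get_reader()
        self.assertIsInstance(reader, io_ops.TFRecordReader)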