Example #1
0
def write_tfrecords(protos, output_path, options=None):
  """Serializes each proto in `protos` and writes it to `output_path`.

  Records are written in their original order. When `output_path` is a
  sharded file spec (e.g., foo@2), the records are striped round-robin
  across the individual shard components (foo-00000-of-00002 and
  foo-00001-of-00002), so the order of records within each shard may
  differ from the order in `protos`.

  Args:
    protos: An iterable of protobufs. The objects we want to write out.
    output_path: str. The filepath where we want to write protos.
    options: A python_io.TFRecordOptions object for the writer.
  """
  if not options:
    options = make_tfrecord_options(output_path)

  # Simple case first: a single, unsharded output file.
  if not sharded_file_utils.is_sharded_file_spec(output_path):
    with make_tfrecord_writer(output_path, options) as writer:
      for proto in protos:
        writer.write(proto.SerializeToString())
    return

  _, shard_count, _ = sharded_file_utils.parse_sharded_file_spec(output_path)
  with contextlib2.ExitStack() as stack:
    # Open one writer per shard; ExitStack closes them all on exit.
    shard_writers = []
    for shard_index in range(shard_count):
      shard_name = sharded_file_utils.sharded_filename(output_path, shard_index)
      shard_writers.append(
          stack.enter_context(make_tfrecord_writer(shard_name, options)))
    # Stripe records across shards round-robin.
    for record_index, proto in enumerate(protos):
      shard_writers[record_index % shard_count].write(
          proto.SerializeToString())
Example #2
0
def read_tfrecords(path, proto=None, max_records=None, compression_type=None):
  """Yields the parsed records in a TFRecord file path.

  Note that path can be sharded filespec (path@N) in which case this function
  will read each shard in order; i.e. shard 0 will read each entry in order,
  then shard 1, ...

  Args:
    path: String. A path to a TFRecord file containing protos.
    proto: A proto class. proto.FromString() will be called on each serialized
      record in path to parse it.
    max_records: int >= 0 or None. Maximum number of records to read from path.
      If None, the default, all records will be read.
    compression_type: 'GZIP', 'ZLIB', '' (uncompressed), or None to autodetect
      based on file extension.

  Yields:
    proto.FromString() values on each record in path in order.
  """
  shard_paths = (
      sharded_file_utils.generate_sharded_filenames(path)
      if sharded_file_utils.is_sharded_file_spec(path) else [path])

  n_read = 0
  for shard_path in shard_paths:
    for record in Reader(shard_path, proto, compression_type):
      # Stop as soon as we have yielded max_records records.
      if max_records is not None and n_read >= max_records:
        return
      n_read += 1
      yield record
Example #3
0
def write_tfrecords(protos, output_path, compression_type=None):
  """Serializes each proto in `protos` and writes it to `output_path`.

  Records are written in their original order. When `output_path` is a
  sharded file spec (e.g., foo@2), the records are striped round-robin
  across the individual shard components (foo-00000-of-00002 and
  foo-00001-of-00002), so the order of records within each shard may
  differ from the order in `protos`.

  Args:
    protos: An iterable of protobufs. The objects we want to write out.
    output_path: str. The filepath where we want to write protos.
    compression_type: 'GZIP', 'ZLIB', '' (uncompressed), or None to autodetect
      based on file extension.
  """
  # Simple case first: a single, unsharded output file.
  if not sharded_file_utils.is_sharded_file_spec(output_path):
    with Writer(output_path, compression_type=compression_type) as writer:
      for proto in protos:
        writer.write(proto)
    return

  _, num_shards, _ = sharded_file_utils.parse_sharded_file_spec(output_path)
  with contextlib2.ExitStack() as stack:
    # Open one writer per shard; ExitStack closes them all on exit.
    shard_writers = []
    for shard in range(num_shards):
      shard_name = sharded_file_utils.sharded_filename(output_path, shard)
      shard_writers.append(
          stack.enter_context(Writer(shard_name, compression_type)))
    # Stripe records across shards round-robin.
    for idx, proto in enumerate(protos):
      shard_writers[idx % num_shards].write(proto)
Example #4
0
def read_shard_sorted_tfrecords(path,
                                key,
                                proto=None,
                                max_records=None,
                                options=None):
  """Yields the parsed records in a TFRecord file path in sorted order.

  The input TFRecord file must have each shard already in sorted order when
  using the key function for comparison (but elements can be interleaved across
  shards). Under those constraints, the elements will be yielded in a global
  sorted order.

  Args:
    path: String. A path to a TFRecord-formatted file containing protos.
    key: Callable. A function that takes as input a single instance of the proto
      class and returns a value on which the comparison for sorted ordering is
      performed.
    proto: A proto class. proto.FromString() will be called on each serialized
      record in path to parse it.
    max_records: int >= 0 or None. Maximum number of records to read from path.
      If None, the default, all records will be read.
    options: A python_io.TFRecordOptions object for the reader.

  Yields:
    proto.FromString() values on each record in path in sorted order.
  """
  if proto is None:
    proto = example_pb2.Example
  if options is None:
    options = make_tfrecord_options(path)

  shard_paths = (
      sharded_file_utils.generate_sharded_filenames(path)
      if sharded_file_utils.is_sharded_file_spec(path) else [path])

  def _keyed_records(shard_path):
    # Lazily parse one shard, pairing each record with its sort key so
    # heapq.merge can compare records across shards.
    for buf in python_io.tf_record_iterator(shard_path, options):
      record = proto.FromString(buf)
      yield key(record), record

  # Each shard is sorted, so an n-way heap merge yields a global order.
  merged = heapq.merge(*[_keyed_records(p) for p in shard_paths])
  for count, (_, record) in enumerate(merged):
    if max_records is not None and count >= max_records:
      return
    yield record
Example #5
0
def read_shard_sorted_tfrecords(path,
                                key,
                                proto=None,
                                max_records=None,
                                compression_type=None):
  """Yields the parsed records in a TFRecord file path in sorted order.

  The input TFRecord file must have each shard already in sorted order when
  using the key function for comparison (but elements can be interleaved across
  shards). Under those constraints, the elements will be yielded in a global
  sorted order.

  Args:
    path: String. A path to a TFRecord-formatted file containing protos.
    key: Callable. A function that takes as input a single instance of the proto
      class and returns a value on which the comparison for sorted ordering is
      performed.
    proto: A proto class. proto.FromString() will be called on each serialized
      record in path to parse it.
    max_records: int >= 0 or None. Maximum number of records to read from path.
      If None, the default, all records will be read.
    compression_type: 'GZIP', 'ZLIB', '' (uncompressed), or None to autodetect
      based on file extension.

  Yields:
    proto.FromString() values on each record in path in sorted order.
  """
  if sharded_file_utils.is_sharded_file_spec(path):
    shard_paths = sharded_file_utils.generate_sharded_filenames(path)
  else:
    shard_paths = [path]

  # Pair each record with its sort key so heapq.merge can compare records
  # from different shards.
  keyed_streams = []
  for shard_path in shard_paths:
    records = Reader(shard_path, proto,
                     compression_type=compression_type).iterate()
    keyed_streams.append((key(record), record) for record in records)

  # Each shard is sorted, so an n-way heap merge yields a global order.
  emitted = 0
  for _, record in heapq.merge(*keyed_streams):
    if max_records is not None and emitted >= max_records:
      return
    emitted += 1
    yield record
Example #6
0
def read_tfrecords(path, proto=None, max_records=None, options=None):
    """Yields the parsed records in a TFRecord file path.

  Note that path can be sharded filespec (path@N) in which case this function
  will read each shard in order; i.e. shard 0 will read each entry in order,
  then shard 1, ...

  Args:
    path: String. A path to a TFRecord file containing protos.
    proto: A proto class. proto.FromString() will be called on each serialized
      record in path to parse it.
    max_records: int >= 0 or None. Maximum number of records to read from path.
      If None, the default, all records will be read.
    options: A python_io.TFRecordOptions object for the reader.

  Yields:
    proto.FromString() values on each record in path in order.
  """
    # Use identity tests against None (matching read_shard_sorted_tfrecords)
    # rather than truthiness, so a falsy-but-valid argument is not silently
    # replaced by the default.
    if proto is None:
        proto = example_pb2.Example

    if options is None:
        options = make_tfrecord_options(path)

    if sharded_file_utils.is_sharded_file_spec(path):
        paths = sharded_file_utils.generate_sharded_filenames(path)
    else:
        paths = [path]

    i = 0
    for path in paths:
        for buf in python_io.tf_record_iterator(path, options):
            i += 1
            # Stop once we have yielded max_records records.
            if max_records is not None and i > max_records:
                return
            yield proto.FromString(buf)
 def testIsShardedFileSpec(self, spec, expected):
   """Checks that is_sharded_file_spec classifies `spec` as `expected`."""
   actual = io.is_sharded_file_spec(spec)
   # Fixed the failure message: it previously said 'io.IshShardedFileSpec',
   # which is not the name of the function under test.
   self.assertEqual(
       actual, expected,
       'io.is_sharded_file_spec({0}) is {1} expected {2}'.format(
           spec, actual, expected))