Example #1
def factory():
    # `epoch` lives in the enclosing function's scope; each call to
    # factory() makes one fresh pass over the dataset and bumps the counter.
    nonlocal epoch
    epoch += 1
    for item in dataset.as_numpy_iterator():
        # Deserialize the raw bytes into an (example_id, example) pair,
        # using the prototype to recover the dataclass structure.
        ex_id, ex = flax.serialization.from_bytes(target=prototype_object,
                                                  encoded_bytes=item)
        yield train_util.ExampleWithMetadata(epoch, ex_id, ex)
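This fragment only compiles inside an enclosing function that defines `epoch`, because of the `nonlocal` declaration. Below is a minimal, self-contained sketch of the same pattern; the function name and the `epoch = -1` initialization are illustrative assumptions, not taken from the original codebase:

def build_one_pass_iterator_sketch(dataset):
    epoch = -1  # assumed initialization; first factory() call yields epoch 0

    def factory():
        nonlocal epoch
        epoch += 1
        for item in dataset:
            yield (epoch, item)

    return factory

factory = build_one_pass_iterator_sketch(["a", "b"])
assert list(factory()) == [(0, "a"), (0, "b")]
assert list(factory()) == [(1, "a"), (1, "b")]

Because `epoch` is captured by reference rather than copied, every generator produced by the factory shares one counter, which is what lets restarts of the iterator count as new epochs.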
Example #2
import itertools

import flax
import tensorflow as tf
from absl import logging  # assumption: any logger with .warning() works here

# jax_util and train_util are project-local modules not shown here:
# jax_util.synthesize_dataclass builds a prototype instance of the example
# dataclass, and train_util.ExampleWithMetadata wraps each yielded example.


def build_sampling_iterator(
    tfrecord_path,
    example_type,
    num_parallel_reads=16,
    shuffle_buffer=2048,
    truncate_at=None,
):
    """Build a sampling dataset iterator for individual examples.

  Args:
    tfrecord_path: Path to the TFRecord files to use. Can include a * glob
      pattern to load multiple files.
    example_type: Dataclass to use to deserialize the results.
    num_parallel_reads: How many files to read from at the same time.
    shuffle_buffer: How many examples to store in the shuffle buffer (after
      interleaving chunks).
    truncate_at: How many examples to produce.

  Yields:
    train_util.ExampleWithMetadata objects, where epoch starts at 0
    and increments every time we make a full pass through the dataset. No
    batching is performed.
  """
    if truncate_at is not None and num_parallel_reads is not None:
        # Parallel reads interleave records from several files, so "the first
        # truncate_at examples" would not be a deterministic set. Fall back to
        # sequential reads so that truncation is reproducible.
        logging.warning(
            "Disabling num_parallel_reads due to truncated dataset.")
        num_parallel_reads = None

    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_path),
                                      num_parallel_reads=num_parallel_reads)
    if truncate_at:
        dataset = dataset.take(truncate_at)
    dataset = dataset.shuffle(shuffle_buffer)

    # flax.serialization.from_bytes needs a structurally matching target, so
    # build a prototype (example_id, example) pair to deserialize into.
    prototype_object = (0, jax_util.synthesize_dataclass(example_type))

    # Iterate forever; `epoch` counts complete passes through the dataset.
    for epoch in itertools.count():
        for item in dataset.as_numpy_iterator():
            ex_id, ex = flax.serialization.from_bytes(target=prototype_object,
                                                      encoded_bytes=item)
            yield train_util.ExampleWithMetadata(epoch, ex_id, ex)
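A hedged usage sketch follows. The file pattern, the MyExample dataclass, and the `epoch`/`example_id` field names (inferred from the positional ExampleWithMetadata constructor call above) are illustrative assumptions, not part of the original code:

import dataclasses
import itertools

@dataclasses.dataclass
class MyExample:  # hypothetical example dataclass
    tokens: list
    label: int

it = build_sampling_iterator(
    "data/train-*.tfrecord",  # assumed path/glob
    example_type=MyExample,
    shuffle_buffer=1024,
)

# The iterator never terminates on its own (epochs increment forever),
# so bound consumption explicitly.
for ex in itertools.islice(it, 5):
    print(ex.epoch, ex.example_id)  # field names assumed from the constructor

Bounding with itertools.islice (or checking ex.epoch against a limit) matters because the itertools.count() loop makes the iterator infinite by design.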