import itertools

from absl import logging  # Assumed; stdlib `logging` exposes the same call.
import flax
import tensorflow as tf

# `jax_util` (synthesize_dataclass) and `train_util` (ExampleWithMetadata) are
# project-local modules; their import paths depend on the surrounding package.


  # Fragment: `factory` is a nested function. It relies on `epoch`, `dataset`,
  # and `prototype_object` being bound in an enclosing scope (not shown here);
  # `nonlocal epoch` rebinds that enclosing counter on each call.
  def factory():
    nonlocal epoch
    epoch += 1  # Each call to factory() is one fresh pass over the dataset.
    for item in dataset.as_numpy_iterator():
      # Deserialize the serialized (example_id, example) pair.
      ex_id, ex = flax.serialization.from_bytes(
          target=prototype_object, encoded_bytes=item)
      yield train_util.ExampleWithMetadata(epoch, ex_id, ex)
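
# The sketch below is illustrative and not from the original source: it shows
# one way the factory fragment above can be embedded. `build_example_factory`
# is a hypothetical name; it owns the `epoch` counter, the dataset, and the
# deserialization prototype, and returns the factory so that each call replays
# the data with a freshly incremented epoch tag.
def build_example_factory(tfrecord_path, example_type):
  """Hypothetical builder closing over the state that `factory` needs."""
  dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_path))
  prototype_object = (0, jax_util.synthesize_dataclass(example_type))
  epoch = -1  # The first factory() call increments this to 0.

  def factory():
    nonlocal epoch
    epoch += 1
    for item in dataset.as_numpy_iterator():
      ex_id, ex = flax.serialization.from_bytes(
          target=prototype_object, encoded_bytes=item)
      yield train_util.ExampleWithMetadata(epoch, ex_id, ex)

  return factory
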
def build_sampling_iterator(
    tfrecord_path,
    example_type,
    num_parallel_reads=16,
    shuffle_buffer=2048,
    truncate_at=None,
):
  """Build a sampling dataset iterator for individual examples.

  Args:
    tfrecord_path: Path to the TFRecord files to use. Can include a * glob
      pattern to load multiple files.
    example_type: Dataclass to use to deserialize the results.
    num_parallel_reads: How many files to read from at the same time.
    shuffle_buffer: How many examples to store in the shuffle buffer (after
      interleaving chunks).
    truncate_at: How many examples to produce.

  Yields:
    train_util.ExampleWithMetadata objects, where epoch starts at 0 and
    increments every time we make a full pass through the dataset. No batching
    is performed.
  """
  if truncate_at is not None and num_parallel_reads is not None:
    # Can't guarantee iteration order when truncating.
    logging.warning("Disabling num_parallel_reads due to truncated dataset.")
    num_parallel_reads = None

  dataset = tf.data.TFRecordDataset(
      tf.io.gfile.glob(tfrecord_path), num_parallel_reads=num_parallel_reads)
  if truncate_at:
    dataset = dataset.take(truncate_at)
  dataset = dataset.shuffle(shuffle_buffer)

  # Zero-filled (example_id, example) pair used as the deserialization target.
  prototype_object = (0, jax_util.synthesize_dataclass(example_type))
  for epoch in itertools.count():
    for item in dataset.as_numpy_iterator():
      ex_id, ex = flax.serialization.from_bytes(
          target=prototype_object, encoded_bytes=item)
      yield train_util.ExampleWithMetadata(epoch, ex_id, ex)
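

# Illustrative usage sketch, not from the original source. It shows how a
# caller might drain a few examples from the sampling iterator; the TFRecord
# glob, the buffer size, and the `example_type` argument passed in by the
# caller are all placeholders.
def _demo_sampling_iterator(example_type):
  """Pulls a handful of shuffled examples from a hypothetical shard glob.

  Args:
    example_type: A flax-serializable dataclass; the concrete type comes from
      the surrounding project.
  """
  iterator = build_sampling_iterator(
      "/tmp/examples-*.tfrecord",  # Placeholder glob of TFRecord shards.
      example_type=example_type,
      shuffle_buffer=512,
  )
  # The iterator is infinite: `epoch` starts at 0 and increments once per full
  # pass, so bound the loop explicitly with islice.
  for wrapped in itertools.islice(iterator, 5):
    print(wrapped)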