def load_dataset_metadata(metadata_filename):
  """Read a dataset's metadata JSON file.

  Args:
    metadata_filename: Filename containing dataset metadata.

  Returns:
    Padding configuration and edge types for the dataset.
  """
  with gfile.GFile(metadata_filename, "r") as fp:
    metadata = json.load(fp)
  # Rehydrate the serialized PaddingConfig into a real dataclass instance,
  # using a synthesized placeholder object as the deserialization target.
  padding_config = flax.serialization.from_state_dict(
      target=jax_util.synthesize_dataclass(graph_bundle.PaddingConfig),
      state=metadata["padding_config"])
  return padding_config, metadata["edge_types"]
def build_sampling_iterator(
    tfrecord_path,
    example_type,
    num_parallel_reads=16,
    shuffle_buffer=2048,
    truncate_at=None,
):
  """Build a sampling dataset iterator for individual examples.

  Args:
    tfrecord_path: Path to the TFRecord files to use. Can include a * glob
      pattern to load multiple files.
    example_type: Dataclass to use to deserialize the results.
    num_parallel_reads: How many files to read from at the same time.
    shuffle_buffer: How many examples to store in the shuffle buffer (after
      interleaving chunks).
    truncate_at: How many examples to produce.

  Yields:
    train_util.ExampleWithMetadata objects, where epoch starts at 0 and
    increments every time we make a full pass through the dataset. No batching
    is performed.
  """
  if truncate_at is not None and num_parallel_reads is not None:
    # Parallel reads interleave files nondeterministically, so a truncated
    # prefix would not be a reproducible subset of the data.
    logging.warning(
        "Disabling num_parallel_reads due to truncated dataset.")
    num_parallel_reads = None

  shard_filenames = tf.io.gfile.glob(tfrecord_path)
  dataset = tf.data.TFRecordDataset(
      shard_filenames, num_parallel_reads=num_parallel_reads)
  if truncate_at:
    dataset = dataset.take(truncate_at)
  dataset = dataset.shuffle(shuffle_buffer)

  # Placeholder (id, example) pair that tells flax how to decode each record.
  prototype_object = (0, jax_util.synthesize_dataclass(example_type))
  for epoch in itertools.count():
    for record in dataset.as_numpy_iterator():
      ex_id, ex = flax.serialization.from_bytes(
          target=prototype_object, encoded_bytes=record)
      yield train_util.ExampleWithMetadata(epoch, ex_id, ex)
def test_synthesize_dataclass(self):
  """Synthesizing a nested dataclass fills every leaf with a default."""

  @dataclasses.dataclass
  class Child:
    x: jax_util.NDArray
    y: int
    z: Any

  @dataclasses.dataclass
  class Parent:
    a: str
    b: Child

  result = jax_util.synthesize_dataclass(Parent)
  # Primitives get zero-values ("" / 0); array and Any annotations become
  # LeafPlaceholder markers.
  expected = Parent(
      a="",
      b=Child(
          x=jax_util.LeafPlaceholder(jax_util.NDArray),
          y=0,
          z=jax_util.LeafPlaceholder(Any)))  # type:ignore
  self.assertEqual(result, expected)
def test_synthesize_dataclass(self):
  """Checks that synthesize_dataclass recursively fills in default leaves."""

  @dataclasses.dataclass
  class Inner:
    x: jax_util.NDArray
    y: int
    z: Any

  @dataclasses.dataclass
  class Outer:
    a: str
    b: Inner  # pytype: disable=invalid-annotation  # enable-bare-annotations

  synthesized = jax_util.synthesize_dataclass(Outer)
  # Primitive annotations become zero-values ("" for str, 0 for int); array
  # and Any annotations become LeafPlaceholder markers.
  self.assertEqual(
      synthesized,
      Outer(a="",
            b=Inner(x=jax_util.LeafPlaceholder(jax_util.NDArray),
                    y=0,
                    z=jax_util.LeafPlaceholder(Any))))  # type:ignore
def build_one_pass_iterator_factory(
    tfrecord_path,
    example_type,
    truncate_at=None,
    skip_first=0,
):
  """Build a deterministic one-epoch iterator for unbatched examples.

  Args:
    tfrecord_path: Path to the TFRecord files to use. Can include a * glob
      pattern to load multiple files.
    example_type: Dataclass to use to deserialize the results.
    truncate_at: Number of examples to truncate the table at. Determines the
      effective size of the dataset.
    skip_first: Number of examples to skip at the beginning of the dataset.

  Returns:
    Callable with no args that, when called, returns a new dataset iterator.
    This iterator produces train_util.ExampleWithMetadata objects, where epoch
    gives the number of times the factory function has been called. Each
    returned iterator will make exactly one pass through the dataset and then
    stop.
  """
  # No shuffle and no parallel reads: iteration order must be deterministic.
  dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_path))
  dataset = dataset.skip(skip_first)
  if truncate_at:
    dataset = dataset.take(truncate_at)

  # Placeholder (id, example) pair that tells flax how to decode each record.
  prototype_object = (0, jax_util.synthesize_dataclass(example_type))
  epoch = 0

  def factory():
    # NOTE: factory is deliberately NOT a generator function. If `epoch += 1`
    # lived inside a generator body, it would run at the first next() rather
    # than when the factory is called, and later yields would re-read the
    # shared `epoch` variable — so a second factory() call could change the
    # epoch reported by an iterator already in flight. Incrementing eagerly
    # and capturing the value by argument keeps each iterator's epoch fixed
    # and matches the documented contract.
    nonlocal epoch
    epoch += 1
    return _iterate(epoch)

  def _iterate(current_epoch):
    # One deterministic pass over the dataset, tagged with a fixed epoch.
    for item in dataset.as_numpy_iterator():
      ex_id, ex = flax.serialization.from_bytes(target=prototype_object,
                                                encoded_bytes=item)
      yield train_util.ExampleWithMetadata(current_epoch, ex_id, ex)

  return factory