import dataclasses
import itertools
import json
import unittest
from typing import Any

from absl import logging
import flax
import flax.serialization
import tensorflow as tf
from tensorflow.io import gfile

# jax_util, train_util, and graph_bundle are project-local modules assumed to
# be importable from the surrounding codebase.


def load_dataset_metadata(metadata_filename):
    """Helper function to load dataset metadata.

  Args:
    metadata_filename: Filename containing dataset metadata.

  Returns:
    Padding configuration and edge types for the dataset.
  """
    with gfile.GFile(metadata_filename, "r") as fp:
        metadata = json.load(fp)

    edge_types = metadata["edge_types"]
    padding_config = flax.serialization.from_state_dict(
        target=jax_util.synthesize_dataclass(graph_bundle.PaddingConfig),
        state=metadata["padding_config"])
    return padding_config, edge_types
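

# Usage sketch (hypothetical): load metadata produced by the matching
# serialization step. The path below is a placeholder.
def _example_load_metadata():
    padding_config, edge_types = load_dataset_metadata("/tmp/metadata.json")
    # padding_config is a graph_bundle.PaddingConfig; edge_types is the raw
    # "edge_types" entry from the JSON file.
    return padding_config, edge_types
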
def build_sampling_iterator(
    tfrecord_path,
    example_type,
    num_parallel_reads=16,
    shuffle_buffer=2048,
    truncate_at=None,
):
    """Build a sampling dataset iterator for individual examples.

  Args:
    tfrecord_path: Path to the TFRecord files to use. Can include a * glob
      pattern to load multiple files.
    example_type: Dataclass to use to deserialize the results.
    num_parallel_reads: How many files to read from at the same time.
    shuffle_buffer: How many examples to store in the shuffle buffer (after
      interleaving chunks).
    truncate_at: How many examples to produce.

  Yields:
    train_util.ExampleWithMetadata objects, where epoch starts at 0
    and increments every time we make a full pass through the dataset. No
    batching is performed.
  """
    if truncate_at is not None and num_parallel_reads is not None:
        # Parallel reads make iteration order nondeterministic, so truncation
        # would keep an arbitrary subset of examples.
        logging.warning(
            "Disabling num_parallel_reads due to truncated dataset.")
        num_parallel_reads = None

    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_path),
                                      num_parallel_reads=num_parallel_reads)
    if truncate_at:
        dataset = dataset.take(truncate_at)
    dataset = dataset.shuffle(shuffle_buffer)

    prototype_object = (0, jax_util.synthesize_dataclass(example_type))

    for epoch in itertools.count():
        for item in dataset.as_numpy_iterator():
            ex_id, ex = flax.serialization.from_bytes(target=prototype_object,
                                                      encoded_bytes=item)
            yield train_util.ExampleWithMetadata(epoch, ex_id, ex)
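

# Usage sketch (hypothetical): draw a few shuffled examples. ExampleDataclass
# and the TFRecord glob are placeholders, not part of the original code.
def _example_sampling(ExampleDataclass):
    iterator = build_sampling_iterator(
        "/tmp/train-*.tfrecord",
        example_type=ExampleDataclass,
    )
    # The iterator is infinite (it rewinds and bumps `epoch` after each full
    # pass), so bound it explicitly.
    for example in itertools.islice(iterator, 4):
        print(example)  # a train_util.ExampleWithMetadata
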
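
# The test method below is assumed to live in a TestCase subclass; the class
# name here is a placeholder.
class JaxUtilTest(unittest.TestCase):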
    def test_synthesize_dataclass(self):
        @dataclasses.dataclass
        class Inner:
            x: jax_util.NDArray
            y: int
            z: Any

        @dataclasses.dataclass
        class Outer:
            a: str
            b: Inner  # pytype: disable=invalid-annotation  # enable-bare-annotations

        synthesized = jax_util.synthesize_dataclass(Outer)

        self.assertEqual(
            synthesized,
            Outer(a="",
                  b=Inner(x=jax_util.LeafPlaceholder(jax_util.NDArray),
                          y=0,
                          z=jax_util.LeafPlaceholder(Any))))  # type:ignore
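

# A minimal sketch of the recursive construction the test above exercises,
# NOT the real jax_util.synthesize_dataclass: build a dataclass instance
# whose fields get zero-like defaults, with non-trivial leaves replaced by
# placeholders. `_LeafPlaceholder` and `_synthesize_sketch` are hypothetical
# names; they reuse `dataclasses` and `Any` from the imports above.
@dataclasses.dataclass(frozen=True)
class _LeafPlaceholder:
    """Stands in for a leaf whose concrete value is not known yet."""
    annotation: Any


def _synthesize_sketch(cls):
    """Builds an instance of `cls` with zero-like or placeholder fields."""
    if dataclasses.is_dataclass(cls):
        # Recurse into nested dataclasses field by field.
        kwargs = {
            field.name: _synthesize_sketch(field.type)
            for field in dataclasses.fields(cls)
        }
        return cls(**kwargs)
    if cls in (int, float, str, bool):
        return cls()  # zero-like default: 0, 0.0, "", False
    # Arrays, Any, and other leaves become placeholders.
    return _LeafPlaceholder(cls)
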
def build_one_pass_iterator_factory(
    tfrecord_path,
    example_type,
    truncate_at=None,
    skip_first=0,
):
    """Build a deterministic one-epoch iterator for unbatched examples.

  Args:
    tfrecord_path: Path to the TFRecord files to use. Can include a * glob
      pattern to load multiple files.
    example_type: Dataclass to use to deserialize the results.
    truncate_at: Number of examples to truncate the table at. Determines the
      effective size of the dataset.
    skip_first: Number of examples to skip at the beginnning of the dataset.

  Returns:
    Callable with no args that, when called, returns a new dataset iterator.
    This iterator produces train_util.ExampleWithMetadata objects, where epoch
    gives the
    number of times the factory function has been called. Each returned iterator
    will make exactly one pass through the dataset and then stop.
  """
    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_path))
    dataset = dataset.skip(skip_first)
    if truncate_at:
        dataset = dataset.take(truncate_at)

    prototype_object = (0, jax_util.synthesize_dataclass(example_type))
    epoch = 0

    def factory():
        nonlocal epoch
        epoch += 1
        for item in dataset.as_numpy_iterator():
            ex_id, ex = flax.serialization.from_bytes(target=prototype_object,
                                                      encoded_bytes=item)
            yield train_util.ExampleWithMetadata(epoch, ex_id, ex)

    return factory
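

# Usage sketch (hypothetical): run one deterministic validation pass per
# evaluation round. The path, dataclass, and consumer are placeholders.
def _example_validation_passes(ExampleDataclass, process_fn):
    make_valid_iterator = build_one_pass_iterator_factory(
        "/tmp/valid-*.tfrecord",
        example_type=ExampleDataclass,
        truncate_at=1000,
    )
    for _ in range(3):  # e.g. once per evaluation round
        # Each call makes exactly one pass; `epoch` counts the calls.
        for example in make_valid_iterator():
            process_fn(example)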