Example #1
def save(self,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
  """Implements the save function and checkpoint functionality."""
  if context.executing_eagerly() and checkpoint_args:
    save_dataset = _SaveDataset(self, path, shard_func, compression)
    save_iterator = iter(save_dataset)

    if "checkpoint" in checkpoint_args:
      raise ValueError(
          "'Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
          "to include 'checkpoint'."
      )
    checkpoint = checkpoint_lib.Checkpoint(iterator=save_iterator)
    checkpoint_args["checkpoint"] = checkpoint
    manager = checkpoint_management.CheckpointManager(**checkpoint_args)
    checkpoint.restore(manager.latest_checkpoint)

    for _ in save_iterator:
      if "step_counter" in checkpoint_args:
        checkpoint_args["step_counter"].assign_add(delta=1)
      manager.save(check_interval=True)
  else:
    dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
        self, shard_func, path)
    ged_ops.save_dataset(
        dataset._variant_tensor,   # pylint: disable=protected-access
        path=path,
        shard_func_other_args=shard_func.captured_inputs,
        compression=compression,
        shard_func=shard_func,
        use_shard_func=use_shard_func)
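For reference, a minimal usage sketch of this checkpointed save path, assuming a TensorFlow release whose `tf.data.experimental.save` accepts `checkpoint_args`. The paths are hypothetical; the dict is forwarded to `tf.train.CheckpointManager` as shown above:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
step_counter = tf.Variable(0, trainable=False)
tf.data.experimental.save(
    dataset,
    "/tmp/saved_data",  # hypothetical output directory
    checkpoint_args={
        "checkpoint_interval": 50,  # forwarded to tf.train.CheckpointManager
        "step_counter": step_counter,
        "directory": "/tmp/save_ckpts",  # hypothetical checkpoint directory
        "max_to_keep": 3,
    })
```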
Example #2
def save(dataset, path, compression=None, shard_func=None):
    """Saves the content of the given dataset.

  Example usage:

  >>> import os
  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The dataset is saved across multiple file "shards". By default, the dataset
  output is divided into shards in a round-robin fashion, but custom sharding
  can be specified via the `shard_func` function. For example, you can save the
  dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return 0
  tf.data.experimental.save(
      path="/path/to/data", ..., shard_func=custom_shard_func)
  ```

  NOTE: The directory layout and file format used for saving the dataset are
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
  """
    if (context.executing_eagerly()
            and compat.forward_compatible(2021, 6, 29)):
        save_dataset = _SaveDataset(dataset, path, shard_func, compression)
        for _ in save_dataset:
            pass
    else:
        dataset, shard_func, use_shard_func, path = _set_save_dataset_attributes(
            dataset, shard_func, path)
        gen_experimental_dataset_ops.save_dataset(
            dataset._variant_tensor,  # pylint: disable=protected-access
            path=path,
            shard_func_other_args=shard_func.captured_inputs,
            compression=compression,
            shard_func=shard_func,
            use_shard_func=use_shard_func)
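A hedged sketch of the custom sharding described in the docstring, using the public `tf.data.experimental.save`/`load` API; the path is hypothetical and `shard_by_value` is an illustrative helper that spreads elements across three shard files:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)

def shard_by_value(element):
  # Must return an int64 shard ID; range() elements are already int64.
  return element % 3

tf.data.experimental.save(
    dataset, "/tmp/sharded_data", shard_func=shard_by_value)
restored = tf.data.experimental.load(
    "/tmp/sharded_data",
    element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))
print(sorted(int(x) for x in restored))  # order may differ across shards
```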
Example #3
def save(dataset,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
    """Saves the content of the given dataset.

  Example usage:

  >>> import os
  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The dataset is saved across multiple file "shards". By default, the dataset
  output is divided into shards in a round-robin fashion, but custom sharding
  can be specified via the `shard_func` function. For example, you can save the
  dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return 0
  tf.data.experimental.save(
      path="/path/to/data", ..., shard_func=custom_shard_func)
  ```

  To enable checkpointing, pass in `checkpoint_args` to the `save` method
  as follows:

  ```python
  dataset = tf.data.Dataset.range(100)
  save_dir = "..."
  checkpoint_prefix = "..."
  step_counter = tf.Variable(0, trainable=False)
  checkpoint_args = {
    "checkpoint_interval": 50,
    "step_counter": step_counter,
    "directory": checkpoint_prefix,
    "max_to_keep": 20,
  }
  tf.data.experimental.save(
      dataset, save_dir, checkpoint_args=checkpoint_args)
  ```

  NOTE: The directory layout and file format used for saving the dataset are
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
    checkpoint_args: Optional args for checkpointing which will be passed into
      the `tf.train.CheckpointManager`. If `checkpoint_args` are not specified,
      then checkpointing will not be performed. The `save()` implementation
      creates a `tf.train.Checkpoint` object internally, so users should not
      set the `checkpoint` argument in `checkpoint_args`.

  Raises:
    ValueError: If `checkpoint` is passed into `checkpoint_args`.
  """
    if (context.executing_eagerly() and checkpoint_args
            and compat.forward_compatible(2021, 6, 29)):
        save_dataset = _SaveDataset(dataset, path, shard_func, compression)
        save_iterator = iter(save_dataset)

        if "checkpoint" in checkpoint_args:
            raise ValueError(
                "'Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
                "to include 'checkpoint'.")
        checkpoint = tracking.util.Checkpoint(iterator=save_iterator)
        checkpoint_args["checkpoint"] = checkpoint
        manager = checkpoint_management.CheckpointManager(**checkpoint_args)
        checkpoint.restore(manager.latest_checkpoint)

        for _ in save_iterator:
            if "step_counter" in checkpoint_args:
                checkpoint_args["step_counter"].assign_add(delta=1)
            manager.save(check_interval=True)
    else:
        dataset, shard_func, use_shard_func, path = _set_save_dataset_attributes(
            dataset, shard_func, path)
        gen_experimental_dataset_ops.save_dataset(
            dataset._variant_tensor,  # pylint: disable=protected-access
            path=path,
            shard_func_other_args=shard_func.captured_inputs,
            compression=compression,
            shard_func=shard_func,
            use_shard_func=use_shard_func)
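Because this implementation restores the iterator from `manager.latest_checkpoint` before writing, re-invoking `save` with equivalent `checkpoint_args` resumes an interrupted save rather than restarting it. A minimal sketch with hypothetical paths; note that the dict must be rebuilt per call, since the code above inserts a "checkpoint" key into it:

```python
import tensorflow as tf

def make_checkpoint_args():
  # Rebuilt per call: save() inserts a "checkpoint" key into this dict, so
  # reusing one dict across calls would trip the ValueError guard above.
  return {
      "checkpoint_interval": 1000,
      "step_counter": tf.Variable(0, trainable=False),
      "directory": "/tmp/save_ckpts",  # hypothetical checkpoint directory
      "max_to_keep": 2,
  }

dataset = tf.data.Dataset.range(1_000_000)
# If an earlier run was interrupted, this call restores the save iterator
# from the latest checkpoint in "directory" and continues writing instead
# of starting over from the first element.
tf.data.experimental.save(
    dataset, "/tmp/big_data", checkpoint_args=make_checkpoint_args())
```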
Example #4
def save(dataset, path, compression=None, shard_func=None):
    """Saves the content of the given dataset.

  Example usage:

  >>> import os
  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path,
  ...     tf.TensorSpec(shape=(), dtype=tf.int64))
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The dataset is saved across multiple file "shards". By default, the dataset
  output is divided into shards in a round-robin fashion, but custom sharding
  can be specified via the `shard_func` function. For example, you can save the
  dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return 0
  tf.data.experimental.save(
      path="/path/to/data", ..., shard_func=custom_shard_func)
  ```

  NOTE: The directory layout and file format used for saving the dataset are
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
  """

    if shard_func is None:
        use_shard_func = False
        shard_func = lambda *x: None  # a dummy function that will not be used
    else:
        use_shard_func = True

    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        shard_func,
        "save()",
        input_structure=dataset.element_spec,
        add_to_graph=False)

    coder = nested_structure_coder.StructureCoder()
    encoded = coder.encode_structure(dataset.element_spec)
    gfile.MakeDirs(path)
    with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
        f.write(encoded.SerializeToString())

    path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
    shard_func = wrapped_func.function
    shard_func.add_to_graph(ops.get_default_graph())

    # pylint: disable=protected-access
    dataset = dataset._apply_options()
    gen_experimental_dataset_ops.save_dataset(
        dataset._variant_tensor,
        path=path,
        shard_func_other_args=shard_func.captured_inputs,
        compression=compression,
        shard_func=shard_func,
        use_shard_func=use_shard_func)
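A hedged round-trip sketch for this older variant, where `tf.data.experimental.load` still expects an explicit `element_spec`; the path is hypothetical, and reusing `dataset.element_spec` avoids spelling the spec out by hand:

```python
import tensorflow as tf

dataset = tf.data.Dataset.zip(
    (tf.data.Dataset.range(3), tf.data.Dataset.range(3)))
tf.data.experimental.save(dataset, "/tmp/zipped_data")

# Reuse the original dataset's spec instead of writing it out manually.
restored = tf.data.experimental.load(
    "/tmp/zipped_data", element_spec=dataset.element_spec)
for a, b in restored:
  print(a.numpy(), b.numpy())
```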