Example #1
0
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
    """Saves given named tensor slices in a sharded, multi-client safe fashion.

    The method makes sure the checkpoint directory state is correct in a
    sharded multi-client save. Namely, a barrier is placed after SaveV2 to make
    sure every client has finished writing its files, and another one after
    MergeV2Checkpoints to make sure all metadata has been properly merged.

    Upon exiting, the checkpoint is complete and all directory operations are
    done.

    Args:
      mesh: The Mesh that contains the Tensors to save. Its host mesh
        (`mesh.host_mesh()`) is used both to compute the shard prefixes and
        for the two barriers.
      file_prefix: The prefix of the checkpoint.
      tensor_names: A list of tensor names used in the save op.
      shape_and_slices: A list of shape and slice specifications used in the
        save op. The only supported value is "" as distributed saving with
        slices is not supported yet.
      tensors: A list of tensors used in the save op. The order should match
        `tensor_names`.

    Returns:
      A MergeV2Checkpoints op that merged all metadata.
    """
    # Write the tensors with SaveV2, placed on the device from api.device_name().
    with ops.device(api.device_name()):
        io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

    # Query generated shards and generate MergeV2.
    generated_shards = sharded_prefix(mesh.host_mesh(), [file_prefix],
                                      tensor_names, shape_and_slices, tensors)
    # api.py is still visible to external users but the _global_barrier() isn't
    # intended for public usage.
    # Once we locked down api.py visibility, we shall be able to make the `_`
    # prefix on these APIs go away.

    # Make sure all clients have written the files
    _global_barrier(mesh.host_mesh(), 'SaveV2')  # pylint: disable=protected-access

    # Merge the shard prefixes into `file_prefix`; delete_old_dirs=True asks
    # MergeV2Checkpoints to remove the per-shard directories afterwards.
    with ops.device(api.device_name()):
        merge_op = io_ops.MergeV2Checkpoints(
            checkpoint_prefixes=generated_shards,
            destination_prefix=file_prefix,
            delete_old_dirs=True)

    # Make sure first device in first host has finished merge.
    # pylint: disable=protected-access
    _global_barrier(mesh.host_mesh(), 'MergeV2Checkpoints')
    # pylint: enable=protected-access

    return merge_op
Example #2
0
def main(argv):
    """Dumps a `tf.keras.datasets` dataset into a TF V2 checkpoint.

    The dataset named by `FLAGS.dataset` is loaded, and its four arrays are
    saved under the tensor names "x_train", "y_train", "x_test" and "y_test".
    The checkpoint prefix is `FLAGS.out`. When `FLAGS.out` is empty, it is
    `/tmp/<dataset>/<dataset>`.

    Args:
      argv: Command-line arguments; unused.

    Raises:
      ValueError: If a dataset array has a dtype that maps to tf.string,
        which `tf.py_func` cannot return.
    """
    del argv  # Unused.

    dataset = getattr(tf.keras.datasets, FLAGS.dataset)
    (x_train, y_train), (x_test, y_test) = dataset.load_data()

    def wrap(val):
        """Returns a graph tensor that evaluates to the numpy array `val`."""
        dtype = tf.as_dtype(val.dtype)
        # tf.string is not supported by py_func. Raise explicitly instead of
        # using `assert`, which is stripped when Python runs with -O.
        if dtype == tf.string:
            raise ValueError("tf.string arrays are not supported by py_func.")
        return tf.py_func(lambda: val, [], dtype)

    out_prefix = FLAGS.out or os.path.join("/tmp", FLAGS.dataset,
                                           FLAGS.dataset)
    # Pass the arguments separately so the logger performs the interpolation.
    tf.logging.info("Save %s dataset to %s ckpt.", FLAGS.dataset, out_prefix)

    with tf.Session() as sess:
        sess.run(
            io_ops.save_v2(
                prefix=out_prefix,
                tensor_names=["x_train", "y_train", "x_test", "y_test"],
                shape_and_slices=[""] * 4,
                tensors=[
                    wrap(x_train),
                    wrap(y_train),
                    wrap(x_test),
                    wrap(y_test)
                ]))
 def testRelativePath(self):
   """Round-trips a value through a checkpoint prefix relative to the cwd."""
   # Restore the original working directory when the test finishes, so the
   # chdir below does not leak into other tests running in this process.
   self.addCleanup(os.chdir, os.getcwd())
   os.chdir(self.get_temp_dir())
   self.evaluate(io_ops.save_v2(
       "ckpt", ["x"], [""], [constant_op.constant(100.)]))
   self.assertAllEqual([100.],
                       self.evaluate(io_ops.restore_v2(
                           "ckpt", ["x"], [""], [dtypes.float32])))
Example #4
0
def WriteNpArrays(file_prefix, nmap):
  """Writes a NestedMap of numpy arrays into a TF checkpoint.

  Every (key, array) pair yielded by `nmap.FlattenItems()` is saved under
  `key` as a whole tensor, i.e. with an empty shape-and-slice spec.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap of numpy arrays.

  Raises:
    TypeError: If a value yielded by `nmap.FlattenItems()` is not a numpy
      array.
    ValueError: If an array's dtype maps to tf.string, which `tf.py_func`
      cannot return.
  """
  g = tf.Graph()
  with g.as_default():

    def Wrap(key, val):
      """Returns a graph tensor that evaluates to the numpy array `val`."""
      dtype = tf.as_dtype(val.dtype)
      # tf.string is not supported by py_func. Raise explicitly rather than
      # using `assert`, which is stripped when Python runs with -O.
      if dtype == tf.string:
        raise ValueError(
            "Cannot save %s: tf.string is not supported by py_func." % (key,))
      return tf.py_func(lambda: val, [], dtype)

    names, values = [], []
    for k, v in nmap.FlattenItems():
      if not isinstance(v, np.ndarray):
        raise TypeError("Expected a numpy array for %s, got %s." %
                        (k, type(v).__name__))
      names.append(k)
      values.append(Wrap(k, v))

    save = io_ops.save_v2(
        prefix=file_prefix,
        tensor_names=names,
        tensors=values,
        shape_and_slices=[""] * len(names))

  with tf.Session(graph=g) as sess:
    sess.run(save)
Example #5
0
    def _AddShardedSaveOps(self, variables, checkpoint_prefix, var_key_fn):
        """Adds per-device save ops to save `variables` to `checkpoint_prefix`.

        Variables are grouped by their `device` attribute. Each group is
        written by its own SaveV2 op, placed on that device, to a temporary
        shard whose name is derived from `checkpoint_prefix + "_temp/part"`.
        A final merge op depends on all shard saves. It merges the shards into
        `checkpoint_prefix` and is called with `delete_old_dirs=True`.

        Args:
          variables: Variables to save. Each must expose `device` and
            `read_value()`.
          checkpoint_prefix: Prefix of the final, merged checkpoint.
          var_key_fn: Callable mapping a variable to its checkpoint key.

        Returns:
          The op returned by `gen_io_ops.merge_v2_checkpoints`, built in
          `self._var_graph`.
        """
        with self._var_graph.as_default():
            # Group variables by device so each device writes its own shard.
            per_device = collections.defaultdict(list)
            for var in variables:
                per_device[var.device].append(var)

            # NOTE(review): with no variables, per_device is empty and the
            # merge below receives an empty prefix list; confirm callers never
            # pass an empty `variables`.
            tmp_save_prefix = tf.strings.join(
                [checkpoint_prefix, "_temp/part"])
            num_shards = tf.constant(len(per_device))
            sharded_saves = []
            sharded_prefixes = []

            for shard, (device, var_list) in enumerate(per_device.items()):
                with self._var_graph.device(device):
                    sharded_filename = gen_io_ops.sharded_filename(
                        tmp_save_prefix, shard, num_shards)
                    sharded_prefixes.append(sharded_filename)
                    save_op = io_ops.save_v2(
                        prefix=sharded_filename,
                        tensor_names=[var_key_fn(v) for v in var_list],
                        tensors=[v.read_value() for v in var_list],
                        shape_and_slices=[""] * len(var_list))
                    sharded_saves.append(save_op)

            # Only merge once every shard has been written.
            with tf.control_dependencies(sharded_saves):
                return gen_io_ops.merge_v2_checkpoints(sharded_prefixes,
                                                       checkpoint_prefix,
                                                       delete_old_dirs=True)
Example #6
0
 def save_op(self, filename_tensor, saveables):
     """Creates an op that saves `saveables` to `filename_tensor`.

     Specs whose name starts with 'replicated_' are filtered. Only those whose
     name starts with 'replicated_0' or contains 'avg' are kept, and they are
     saved under their name with the first '/'-separated component removed.
     All other specs are saved under their own name.

     Args:
       filename_tensor: For V1 checkpoints, the filename to write. For V2
         checkpoints it is interpreted as a checkpoint prefix instead.
       saveables: Objects exposing `specs`; each spec has `name`, `tensor`
         and `slice_spec` attributes.

     Returns:
       The op created by `io_ops._save` (V1) or `io_ops.save_v2` (V2).

     Raises:
       RuntimeError: If `self._write_version` is neither V1 nor V2.
     """
     tensor_names = []
     tensors = []
     tensor_slices = []
     for saveable in saveables:
         for spec in saveable.specs:
             name = spec.name
             if name.startswith('replicated_'):
                 # Keep only 'replicated_0*' (presumably replica 0) and 'avg'
                 # entries, and drop the leading replica scope from the name.
                 if not (name.startswith('replicated_0') or 'avg' in name):
                     continue
                 name = '/'.join(name.split('/')[1:])
             tensor_names.append(name)
             tensors.append(spec.tensor)
             tensor_slices.append(spec.slice_spec)
     if self._write_version == saver_pb2.SaverDef.V1:
         return io_ops._save(filename=filename_tensor,
                             tensor_names=tensor_names,
                             tensors=tensors,
                             tensor_slices=tensor_slices)
     elif self._write_version == saver_pb2.SaverDef.V2:
         # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
         # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
         return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                               tensors)
     else:
         # Format rather than concatenate: `_write_version` is compared with
         # SaverDef enum values above, so it is not expected to be a str, and
         # `str + value` would raise TypeError instead of this RuntimeError.
         raise RuntimeError("Unexpected write_version: %s" %
                            (self._write_version,))
Example #7
0
  def save(self, file_prefix, options=None):
    """Writes every registered tensor slice to a checkpoint at `file_prefix`.

    Entries of `self._tensor_slice_dict` whose value is a `SaveSpec` are
    written under the spec's own name and slice spec. A SaveSpec whose tensor
    is `None` is skipped entirely. Any other value is written under its
    checkpoint key and slice-spec key. The save is placed on
    `options.experimental_io_device`, falling back to "cpu:0".

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    if not options:
      options = checkpoint_options.CheckpointOptions()

    entries = []  # (checkpoint name, tensor value, slice spec) triples.
    for checkpoint_key, slices in self._tensor_slice_dict.items():
      for spec_str, value in slices.items():
        if not isinstance(value, saveable_object.SaveSpec):
          entries.append((checkpoint_key, value, spec_str))
          continue
        spec_tensor = value.tensor
        # A `None` tensor means the SaveableObject is only recorded in the
        # object graph; nothing is written to the checkpoint for it.
        if spec_tensor is None:
          continue
        entries.append((value.name, spec_tensor, value.slice_spec))

    names = [entry[0] for entry in entries]
    values = [entry[1] for entry in entries]
    slice_strs = [entry[2] for entry in entries]
    with ops.device(options.experimental_io_device or "cpu:0"):
      return io_ops.save_v2(file_prefix, names, slice_strs, values)
Example #8
0
    def save(self, file_prefix):
        """Save the saveable objects to a checkpoint with `file_prefix`.

        Issues a single SaveV2 op covering every spec of every object in
        `self._saveable_objects`, in iteration order.

        Args:
          file_prefix: A string or scalar string Tensor containing the prefix
            to save under.

        Returns:
          `file_prefix` itself, unchanged. No control dependency on the save
          op is attached to the returned value.
        """
        tensor_names = []
        tensors = []
        tensor_slices = []
        for saveable in self._saveable_objects:
            for spec in saveable.specs:
                tensor_names.append(spec.name)
                tensors.append(spec.tensor)
                tensor_slices.append(spec.slice_spec)
        # NOTE(review): the op returned by save_v2 is discarded. That works
        # when ops execute as they are created (eager) or when automatic
        # control dependencies apply (e.g. inside tf.function). In a plain TF1
        # graph nothing would force the save to run. Confirm callers only use
        # this in those modes.
        io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
        return file_prefix
Example #9
0
 def _BuildSave(self):
   """Builds save ops.

   Records on `self`:
     _save_global_step: the result of `py_utils.GetGlobalStep()`.
     _save_prefix: `self._logdir_ph` joined with "/ckpt-" and the global step
       rendered as a string zero-filled to a width of 8.
     _save_op: a SaveV2 op writing every variable in `self._vars` as a whole
       tensor, keyed by `_VarKey`.
   """
   self._save_global_step = py_utils.GetGlobalStep()
   step_str = tf.as_string(self._save_global_step, width=8, fill="0")
   self._save_prefix = tf.strings.join([self._logdir_ph, "/ckpt-", step_str])
   save_vars = self._vars
   self._save_op = io_ops.save_v2(
       prefix=self._save_prefix,
       tensor_names=list(map(_VarKey, save_vars)),
       tensors=[var.read_value() for var in save_vars],
       shape_and_slices=[""] * len(save_vars))
Example #10
0
def FakeMnistData(tmpdir, train_size=60000, test_size=10000):
    """Fake Mnist data for unit tests.

    Writes an all-ones, MNIST-shaped checkpoint. Images are uint8 tensors of
    shape (N, 28, 28, 1) and labels are uint8 vectors of length N, saved as
    'x_train', 'y_train' (N = train_size) and 'x_test', 'y_test'
    (N = test_size).

    Args:
      tmpdir: Directory in which the checkpoint is written.
      train_size: Number of training examples.
      test_size: Number of test examples.

    Returns:
      The checkpoint prefix, `os.path.join(tmpdir, 'ckpt')`.
    """
    data_path = os.path.join(tmpdir, 'ckpt')
    names = ['x_train', 'y_train', 'x_test', 'y_test']
    with tf.Graph().as_default(), tf.Session() as sess:
        tensors = []
        for num_examples in (train_size, test_size):
            tensors.append(tf.ones((num_examples, 28, 28, 1), dtype=tf.uint8))
            tensors.append(tf.ones((num_examples,), dtype=tf.uint8))
        sess.run(io_ops.save_v2(data_path, names, [''] * len(names), tensors))
    return data_path
Example #11
0
def FakeMnistData(tmpdir, train_size=60000, test_size=10000):
    """Fake Mnist data for unit tests.

    Builds the tensors with `_GetRandomImages` / `_GetRandomLabels` after
    setting the random seed to 91, then saves them as 'x_train', 'y_train'
    (train_size examples) and 'x_test', 'y_test' (test_size examples).

    Args:
      tmpdir: Directory in which the checkpoint is written.
      train_size: Number of training examples.
      test_size: Number of test examples.

    Returns:
      The checkpoint prefix, `os.path.join(tmpdir, 'ckpt')`.
    """
    data_path = os.path.join(tmpdir, 'ckpt')
    with tf.Graph().as_default():
        tf.random.set_seed(91)
        with tf.Session() as sess:
            tensors = []
            for num_examples in (train_size, test_size):
                tensors.append(_GetRandomImages(num_examples))
                tensors.append(_GetRandomLabels(num_examples))
            save_op = io_ops.save_v2(
                data_path,
                tensor_names=['x_train', 'y_train', 'x_test', 'y_test'],
                shape_and_slices=[''] * 4,
                tensors=tensors)
            sess.run(save_op)
    return data_path
  def save(self, file_prefix):
    """Writes every spec of `self._saveable_objects` under `file_prefix`.

    The SaveV2 op is placed on "cpu:0".

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    # Collect (name, tensor, slice_spec) per spec before entering the device
    # scope, matching the order in which the attributes are read.
    entries = [(spec.name, spec.tensor, spec.slice_spec)
               for saveable in self._saveable_objects
               for spec in saveable.specs]
    names = [entry[0] for entry in entries]
    values = [entry[1] for entry in entries]
    slice_specs = [entry[2] for entry in entries]
    with ops.device("cpu:0"):
      return io_ops.save_v2(file_prefix, names, slice_specs, values)
Example #13
0
  def save(self, file_prefix):
    """Writes every spec of `self._saveable_objects` under `file_prefix`.

    The SaveV2 op is placed on "cpu:0", and the returned value is gated on it.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the save ops.
    """
    names = []
    values = []
    slice_specs = []
    all_specs = (spec for saveable in self._saveable_objects
                 for spec in saveable.specs)
    for spec in all_specs:
      names.append(spec.name)
      values.append(spec.tensor)
      slice_specs.append(spec.slice_spec)
    with ops.device("cpu:0"):
      write_op = io_ops.save_v2(file_prefix, names, slice_specs, values)
      with ops.control_dependencies([write_op]):
        return array_ops.identity(file_prefix)
Example #14
0
  def save(self, file_prefix, options=None):
    """Writes every spec of `self._saveable_objects` under `file_prefix`.

    The save runs on `options.experimental_io_device`, or on "cpu:0" when
    that is unset.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    if not options:
      options = checkpoint_options.CheckpointOptions()

    rows = []  # (name, tensor, slice_spec) per spec, in iteration order.
    for saveable in self._saveable_objects:
      rows.extend((spec.name, spec.tensor, spec.slice_spec)
                  for spec in saveable.specs)

    io_device = options.experimental_io_device or "cpu:0"
    with ops.device(io_device):
      return io_ops.save_v2(file_prefix,
                            [row[0] for row in rows],
                            [row[2] for row in rows],
                            [row[1] for row in rows])
Example #15
0
def main(argv):
    """Saves a `tf.keras.datasets` dataset as a TF V2 checkpoint.

    The dataset named by `FLAGS.dataset` is loaded, and its train/test arrays
    are written under "x_train", "y_train", "x_test" and "y_test". The prefix
    is `FLAGS.out`, or "/tmp/" + `FLAGS.dataset` when `FLAGS.out` is falsy.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    loader = getattr(tf.keras.datasets, FLAGS.dataset)
    (x_train, y_train), (x_test, y_test) = loader.load_data()

    def wrap(array):
        # Expose the numpy array as a graph tensor through py_func.
        tf_dtype = tf.as_dtype(array.dtype)
        assert tf_dtype != tf.string  # tf.string is not supported by py_func.
        return tf.py_func(lambda: array, [], tf_dtype)

    names = ["x_train", "y_train", "x_test", "y_test"]
    with tf.Session() as sess:
        if FLAGS.out:
            prefix = FLAGS.out
        else:
            prefix = "/tmp/" + FLAGS.dataset
        save_op = io_ops.save_v2(
            prefix=prefix,
            tensor_names=names,
            shape_and_slices=[""] * len(names),
            tensors=[wrap(arr) for arr in (x_train, y_train, x_test, y_test)])
        sess.run(save_op)
Example #16
0
 def save_fn(trackables, file_prefix):
     """Writes the tensors gathered from `trackables` under `file_prefix`.

     The result of the SaveV2 call is discarded; `file_prefix` is returned
     unchanged.
     """
     names, slice_specs, values, _unused = _get_tensors(trackables)
     io_ops.save_v2(file_prefix, names, slice_specs, values)
     return file_prefix
Example #17
0
def save_stacks_and_parts(trackables, file_prefix):
  """Save stack and part objects to a checkpoint shard.

  Args:
    trackables: Objects whose tensors are collected by `get_tensor_slices`.
    file_prefix: Prefix the shard is written under.

  Returns:
    `file_prefix`, unchanged; the SaveV2 call's result is discarded.
  """
  names, slice_specs, values, _unused = get_tensor_slices(trackables)
  io_ops.save_v2(file_prefix, names, slice_specs, values)
  return file_prefix