Code example #1
0
    def save_fn():
      """Build per-device sharded save ops, then merge them into one checkpoint."""
      shard_count = len(self._single_device_savers)
      save_ops = []
      shard_prefixes = []
      shard_count_tensor = constant_op.constant(shard_count, name="num_shards")
      merge_colocation_device = None
      for shard_index, (device, device_saver) in enumerate(
          sorted(self._single_device_savers.items())):
        merge_colocation_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard_index,
                                          shard_count_tensor)
        shard_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          save_ops.append(device_saver.save(shard_prefix, options))

      with ops.control_dependencies(save_ops):
        # Merge on the io_device if specified, otherwise co-locate the merge
        # op with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(merge_colocation_device))
        with ops.device(merge_device):
          # The V2 write path ends with a metadata merge step, which also
          # attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              shard_prefixes, file_prefix, delete_old_dirs=True)
Code example #2
0
        def save_fn():
            """Run the registered savers, then the default sharded save, then merge."""
            saved_prefixes = []
            # Save with the registered savers. These run before default savers due to
            # the API contract.
            for registered_name, (registered_save_fn,
                                  _) in self._registered_savers.items():
                maybe_saved_prefixes = registered_save_fn(
                    registered_paths[registered_name])
                if maybe_saved_prefixes is None:
                    continue
                flat_prefixes = nest.flatten(maybe_saved_prefixes)
                # Each prefix must be a string tensor; anything else is a
                # contract violation by the registered saver.
                if any(not tensor_util.is_tf_type(prefix)
                       or prefix.dtype != dtypes.string
                       for prefix in flat_prefixes):
                    raise ValueError(
                        "Registered saver must return a (maybe empty) list of "
                        f"string type tensors. Got {maybe_saved_prefixes}."
                    )
                saved_prefixes.extend(flat_prefixes)

            # (Default saver) Save with single device savers.
            shard_count = len(self._single_device_savers)
            sharded_saves = []
            shard_count_tensor = constant_op.constant(shard_count,
                                                      name="num_shards")
            merge_colocation_device = None
            for shard_index, (device, device_saver) in enumerate(
                    sorted(self._single_device_savers.items())):
                merge_colocation_device = device
                with ops.device(saveable_object_util.set_cpu0(device)):
                    shard_prefix = sharded_filename(tmp_checkpoint_prefix,
                                                    shard_index,
                                                    shard_count_tensor)
                saved_prefixes.append(shard_prefix)
                with ops.device(device):
                    # _SingleDeviceSaver will use the CPU device when necessary, but
                    # initial read operations should be placed on the SaveableObject's
                    # device.
                    sharded_saves.append(
                        device_saver.save(shard_prefix, options))

            with ops.control_dependencies(sharded_saves):
                # Merge on the io_device if specified, otherwise co-locate the
                # merge op with the last device used.
                merge_device = (options.experimental_io_device
                                or saveable_object_util.set_cpu0(
                                    merge_colocation_device))
                with ops.device(merge_device):
                    # The V2 write path ends with a metadata merge step, which
                    # also attempts to delete the temporary directory,
                    # "<user-fed prefix>_temp".
                    return gen_io_ops.merge_v2_checkpoints(
                        saved_prefixes, file_prefix, delete_old_dirs=True)
Code example #3
0
  def __init__(self, saveable_objects):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored, and
        objects extending `SaveableHook` will be called into at save and
        restore time.

    Raises:
      ValueError: If an element of `saveable_objects` is neither a
        `SaveableObject` nor a `SaveableHook`.
    """
    self._before_save_callbacks = []
    self._after_restore_callbacks = []

    saveable_objects = list(saveable_objects)
    saveables_by_device = {}
    for saveable in saveable_objects:
      is_saveable = isinstance(saveable, saveable_object.SaveableObject)
      is_hook = isinstance(saveable, saveable_hook.SaveableHook)

      if not is_saveable and not is_hook:
        # BUG FIX: the previous message claimed a dictionary was expected,
        # but this constructor takes a list (see the docstring).
        raise ValueError(
            "Expected a list of SaveableObjects, got {}."
            .format(saveable))

      if is_hook:
        self._before_save_callbacks.append(saveable.before_save)
        self._after_restore_callbacks.append(saveable.after_restore)

      # An object may be both a hook and a saveable; group saveables by the
      # host CPU of their device so each device gets one single-device saver.
      if is_saveable:
        host_device = saveable_object_util.set_cpu0(saveable.device)
        saveables_by_device.setdefault(host_device, []).append(saveable)

    self._single_device_savers = {
        device: _SingleDeviceSaver(saveables)
        for device, saveables in saveables_by_device.items()}
Code example #4
0
    def restore(self, file_prefix):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      An operation which restores the `Saver`'s `SaveableObject`s when run, or
      None if executing eagerly.
    """
        restore_ops = []
        for saveable in self._saveable_objects:
            # Pin the read ops to the host CPU of the saveable's device, when
            # the saveable declares one.
            device = (saveable_object_util.set_cpu0(saveable.device)
                      if saveable.device else None)
            with ops.device(device):
                restored_tensors = [
                    io_ops.restore_v2(file_prefix, [spec.name],
                                      [spec.slice_spec], [spec.dtype])[0]
                    for spec in saveable.specs
                ]
                restore_ops.append(
                    saveable.restore(restored_tensors, restored_shapes=None))
        return control_flow_ops.group(restore_ops)
Code example #5
0
    def __init__(self,
                 saveable_objects,
                 registered_savers=None,
                 call_with_mapped_captures=None):
        """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored, and
        objects extending `SaveableHook` will be called into at save and
        restore time.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: TODO — presumably a callable used to invoke
        the registered save/restore functions with remapped captures; confirm
        against `_get_mapped_registered_save_fn`.

    Raises:
      ValueError: If an element of `saveable_objects` is neither a
        `SaveableObject` nor a `SaveableHook`.
    """
        self._before_save_callbacks = []
        self._after_restore_callbacks = []

        saveable_objects = list(saveable_objects)
        saveables_by_device = {}
        for saveable in saveable_objects:
            is_saveable = isinstance(saveable, saveable_object.SaveableObject)
            is_hook = isinstance(saveable, saveable_hook.SaveableHook)

            if not is_saveable and not is_hook:
                # BUG FIX: the previous message said "dictionary", but this
                # constructor takes a list (see the docstring).
                raise ValueError(
                    f"Expected a list of SaveableObjects, got {saveable}."
                )

            if is_hook:
                self._before_save_callbacks.append(saveable.before_save)
                self._after_restore_callbacks.append(saveable.after_restore)

            # An object may be both a hook and a saveable; group saveables by
            # the host CPU of their device so each device gets one
            # single-device saver.
            if is_saveable:
                host_device = saveable_object_util.set_cpu0(saveable.device)
                saveables_by_device.setdefault(host_device,
                                               []).append(saveable)

        self._single_device_savers = {
            device: _SingleDeviceSaver(saveables)
            for device, saveables in saveables_by_device.items()
        }

        self._registered_savers = {}
        if registered_savers:
            for registered_name, trackables in registered_savers.items():
                # Wrap the registered save/restore functions so they can be
                # called with mapped captures when provided.
                save_fn = _get_mapped_registered_save_fn(
                    registration.get_save_function(registered_name),
                    trackables, call_with_mapped_captures)
                restore_fn = _get_mapped_registered_restore_fn(
                    registration.get_restore_function(registered_name),
                    trackables, call_with_mapped_captures)
                self._registered_savers[registered_name] = (save_fn,
                                                            restore_fn)
Code example #6
0
 def tf_function_restore():
   """Wrap each restore op in an identity tensor so a tf.function can return it."""
   op_by_name = restore_fn()
   tensor_by_name = {}
   # tf.functions must return tensors, thus we use control dependencies so
   # that we can return a tensor which depends on the given op.
   with ops.device(saveable_object_util.set_cpu0(first_device)):
     for restore_name, restore_op in op_by_name.items():
       with ops.control_dependencies([restore_op]):
         tensor_by_name[restore_name] = array_ops.identity(file_prefix)
   return tensor_by_name
Code example #7
0
  def __init__(self,
               saveable_objects,
               registered_savers=None,
               call_with_mapped_captures=None):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: TODO — presumably a callable used to invoke
        the registered save/restore functions with remapped captures; confirm
        against `_get_mapped_registered_save_fn`.

    Raises:
      ValueError: If two tensors share the same (checkpoint key, slice spec)
        pair.
    """
    saveable_objects = list(saveable_objects)

    # Keep these two data structures so that we can map restored tensors to
    # the Trackable restore functions.
    self._keys_to_restore_fn = {}
    self._restore_fn_to_keys = {}

    # Extract serialized tensors and separate by device.
    tensors_by_device = {}  # device -> checkpoint key -> (slice_spec ->) tensor
    for saveable in saveable_objects:
      tensor_dict = saveable_object_util.saveable_object_to_tensor_dict(
          [saveable])
      restore_fn = saveable_object_util.saveable_object_to_restore_fn(
          [saveable])

      # Divide tensor_dict by device.
      for checkpoint_key, maybe_tensor in tensor_dict.items():
        if not isinstance(maybe_tensor, dict):
          # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
          maybe_tensor = {"": maybe_tensor}

        for slice_spec, tensor in maybe_tensor.items():
          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            # BUG FIX: corrected the "Recieved" typo in the error message.
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "Checkpoint key-generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          # Group tensors by the host CPU of the device they live on; each
          # group is handled by one _SingleDeviceSaver below.
          host_device = saveable_object_util.set_cpu0(tensor.device)
          (tensors_by_device
           .setdefault(host_device, {})
           .setdefault(checkpoint_key, {})[slice_spec]) = tensor
    self._single_device_savers = {
        device: _SingleDeviceSaver(tensor_slice_dict)
        for device, tensor_slice_dict in tensors_by_device.items()}

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        # Wrap the registered save/restore functions so they can be called
        # with mapped captures when provided.
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)
Code example #8
0
  def save(self, file_prefix):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Users pass "save_path" to save()/restore() (say "myckpt"), and the
    # checkpoint prefix gets fed <save_path><sharded_suffix>. Shards are first
    # written into a temporary directory:
    #
    #   <train dir>/myckpt_temp/
    #      part-?????-of-?????{.index, .data-00000-of-00001}
    #
    # and before save() finishes they are (hopefully, atomically) renamed to
    #
    #   <train dir>/
    #      myckpt{.index, .data-?????-of-?????}
    #
    # so users only ever interact with the prefix they supplied. Save() and
    # Restore() work with the prefix directly, never a physical pathname. (On
    # failure and subsequent restore, an outdated and orphaned temporary
    # directory can be safely removed.)
    sharded_suffix = f"_temp_{uuid.uuid4().hex}/part"

    with ops.device("cpu:0"):
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])

    shard_count = len(self._single_device_savers)
    save_ops = []
    shard_prefixes = []
    shard_count_tensor = constant_op.constant(shard_count, name="num_shards")
    merge_colocation_device = None
    for shard_index, (device, device_saver) in enumerate(
        sorted(self._single_device_savers.items())):
      merge_colocation_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard_index,
                                        shard_count_tensor)
      shard_prefixes.append(shard_prefix)
      with ops.device(device):
        # _SingleDeviceSaver will use the CPU device when necessary, but initial
        # read operations should be placed on the SaveableObject's device.
        save_ops.append(device_saver.save(shard_prefix))

    with ops.control_dependencies(save_ops):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(merge_colocation_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        return gen_io_ops.merge_v2_checkpoints(
            shard_prefixes, file_prefix, delete_old_dirs=True)
Code example #9
0
    def save(self, file_prefix, options=None):
        """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
        options = options or checkpoint_options.CheckpointOptions()
        # Run any registered pre-save hooks before building the save ops.
        for callback in self._before_save_callbacks:
            callback()

        # IMPLEMENTATION DETAILS: most clients should skip.
        #
        # Suffix for any well-formed "checkpoint_prefix", when sharded.
        # Transformations:
        # * Users pass in "save_path" in save() and restore().  Say "myckpt".
        # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
        #
        # Example:
        #   During runtime, a temporary directory is first created, which contains
        #   files
        #
        #     <train dir>/myckpt_temp/
        #        part-?????-of-?????{.index, .data-00000-of-00001}
        #
        #   Before .save() finishes, they will be (hopefully, atomically) renamed to
        #
        #     <train dir>/
        #        myckpt{.index, .data-?????-of-?????}
        #
        #   Filesystems with eventual consistency (such as S3), don't need a
        #   temporary location. Using a temporary directory in those cases might
        #   cause situations where files are not available during copy.
        #
        # Users only need to interact with the user-specified prefix, which is
        # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
        # prefix directly, instead of any physical pathname.  (On failure and
        # subsequent restore, an outdated and orphaned temporary directory can be
        # safely removed.)
        with ops.device("CPU"):
            # For s3:// prefixes, write shards next to the final prefix
            # (".part" suffix) rather than into a temp directory — see the
            # eventual-consistency note above.
            sharded_suffix = array_ops.where(
                string_ops.regex_full_match(file_prefix, "^s3://.*"),
                constant_op.constant(".part"),
                constant_op.constant("_temp_%s/part" % uuid.uuid4().hex))
            tmp_checkpoint_prefix = string_ops.string_join(
                [file_prefix, sharded_suffix])

        # One single-device saver per host device; each one writes one shard.
        num_shards = len(self._single_device_savers)
        sharded_saves = []
        sharded_prefixes = []
        num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
        # Tracks the last device seen so the merge op can be co-located with it
        # when no explicit io_device is configured.
        last_device = None
        for shard, (device, saver) in enumerate(
                sorted(self._single_device_savers.items())):
            last_device = device
            with ops.device(saveable_object_util.set_cpu0(device)):
                shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                                num_shards_tensor)
            sharded_prefixes.append(shard_prefix)
            with ops.device(device):
                # _SingleDeviceSaver will use the CPU device when necessary, but initial
                # read operations should be placed on the SaveableObject's device.
                sharded_saves.append(saver.save(shard_prefix, options))

        with ops.control_dependencies(sharded_saves):
            # Merge on the io_device if specified, otherwise co-locates the merge op
            # with the last device used.
            merge_device = (options.experimental_io_device
                            or saveable_object_util.set_cpu0(last_device))
            with ops.device(merge_device):
                # V2 format write path consists of a metadata merge step.  Once merged,
                # attempts to delete the temporary directory, "<user-fed prefix>_temp".
                return gen_io_ops.merge_v2_checkpoints(sharded_prefixes,
                                                       file_prefix,
                                                       delete_old_dirs=True)