Example #1
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    slice_specs = []
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          tensor_value = tensor.tensor
          # A tensor value of `None` indicates that this SaveableObject gets
          # recorded in the object graph, but that no value is saved in the
          # checkpoint.
          if tensor_value is not None:
            tensor_names.append(tensor.name)
            tensors.append(tensor_value)
            slice_specs.append(tensor.slice_spec)
        else:
          tensor_names.append(checkpoint_key)
          tensors.append(tensor)
          slice_specs.append(slice_spec)
    save_device = options.experimental_io_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)
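A note on the call above: `io_ops.save_v2` corresponds to the public `tf.raw_ops.SaveV2` op. A minimal standalone sketch of the same write path, assuming a writable /tmp directory (the tensor names `v0`/`v1` are illustrative):

```
import tensorflow as tf

tf.io.gfile.makedirs("/tmp/demo_ckpt")
# An empty shape_and_slices entry ("") means "save the whole tensor".
tf.raw_ops.SaveV2(
    prefix="/tmp/demo_ckpt/demo",
    tensor_names=["v0", "v1"],
    shape_and_slices=["", ""],
    tensors=[tf.constant([1.0, 2.0]), tf.constant([3.0])])
print(tf.io.gfile.glob("/tmp/demo_ckpt/demo*"))  # .index and .data files
```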
Example #2
    def restore(self, file_prefix, options=None):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

        Args:
          file_prefix: A string or scalar string Tensor containing the prefix for
            files to read from.
          options: Optional `CheckpointOptions` object.

        Returns:
          A dictionary mapping from SaveableObject names to restore operations.
        """
        options = options or checkpoint_options.CheckpointOptions()
        restore_specs = []
        tensor_structure = []
        for saveable in self._saveable_objects:
            saveable_tensor_structure = []
            tensor_structure.append(saveable_tensor_structure)
            for spec in saveable.specs:
                saveable_tensor_structure.append(spec.name)
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
        restore_device = options.experimental_io_device or "cpu:0"
        with ops.device(restore_device):
            restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
                                                 tensor_slices, tensor_dtypes)
        structured_restored_tensors = nest.pack_sequence_as(
            tensor_structure, restored_tensors)
        restore_ops = {}
        for saveable, restored_tensors in zip(self._saveable_objects,
                                              structured_restored_tensors):
            restore_ops[saveable.name] = saveable.restore(restored_tensors,
                                                          restored_shapes=None)
        return restore_ops
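The matching read path is `tf.raw_ops.RestoreV2`. A sketch that reads back the two tensors written in the sketch under Example #1 (same hypothetical names and prefix; dtypes must match what was saved):

```
import tensorflow as tf

# Results come back in the order of tensor_names.
v0, v1 = tf.raw_ops.RestoreV2(
    prefix="/tmp/demo_ckpt/demo",
    tensor_names=["v0", "v1"],
    shape_and_slices=["", ""],
    dtypes=[tf.float32, tf.float32])
print(v0.numpy(), v1.numpy())  # [1. 2.] [3.]
```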
Example #3
  def setUp(self):
    super(SaverTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 3 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    self.local_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=LOCALHOST)
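The virtual-CPU trick in this `setUp` is available through the public `tf.config` API; a sketch, noting that it must run before TensorFlow initializes its devices:

```
import tensorflow as tf

cpus = tf.config.list_physical_devices("CPU")
# Split the single physical CPU into three logical devices so tests can
# exercise multi-device code paths on one machine.
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 3)
print([d.name for d in tf.config.list_logical_devices("CPU")])
# ['/device:CPU:0', '/device:CPU:1', '/device:CPU:2']
```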
Example #4
    def restore(self, file_prefix, options=None):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

        Args:
          file_prefix: A string or scalar string Tensor containing the prefix for
            files to read from.
          options: Optional `CheckpointOptions` object.

        Returns:
          When not run eagerly or when restoring on a single device, returns a
          dictionary mapping from SaveableObject names to restore operations;
          otherwise, returns an empty dict.
        """
        options = options or checkpoint_options.CheckpointOptions()

        def restore_fn():
            restore_ops = {}
            # Sort by device name to avoid propagating non-deterministic dictionary
            # ordering in some Python versions.
            for device, saver in sorted(self._single_device_savers.items()):
                with ops.device(device):
                    restore_ops.update(saver.restore(file_prefix, options))
            for _, (_, restore_fn) in self._registered_savers.items():
                restore_fn(file_prefix)
            return restore_ops

        # Since this will cause a function re-trace on each restore, limit it
        # to cases where it is needed: eager execution with multiple tasks/
        # single-device savers. Note that the retrace is needed to ensure we
        # pick up the latest values of options like experimental_io_device.
        if context.executing_eagerly() and len(self._single_device_savers) > 1:

            @def_function.function(jit_compile=False)
            def tf_function_restore():
                restore_fn()
                return {}

            restore_ops = tf_function_restore()
        else:
            restore_ops = restore_fn()

        for callback in self._after_restore_callbacks:
            callback()

        return restore_ops
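The re-trace mentioned in the comment matters because a `tf.function` bakes Python-level values, such as a device string, into the graph at trace time; building a fresh function per call guarantees current option values are captured. A toy illustration of that behavior, not the saver's own code:

```
import tensorflow as tf

def make_noop_restore(io_device):
  # Tracing captures the current io_device string in the graph.
  @tf.function(jit_compile=False)
  def noop_restore():
    with tf.device(io_device):
      return tf.constant(0)
  return noop_restore

make_noop_restore("/job:localhost")()  # traced with /job:localhost
make_noop_restore("/cpu:0")()          # fresh trace picks up /cpu:0
```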
Example #5
  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensor_dtypes = []
    slice_specs = []

    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        tensor_dtypes.append(tensor.dtype)
        if isinstance(tensor, saveable_object.SaveSpec):
          slice_specs.append(tensor.slice_spec)
          tensor_names.append(tensor.name)
        else:
          slice_specs.append(slice_spec)
          tensor_names.append(checkpoint_key)

    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, slice_specs, tensor_dtypes)

    restored_tensor_dict = {}
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec in tensor_slices:
        restored_tensor = restored_tensors.pop(0)
        restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
            restored_tensor)
    return restored_tensor_dict
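To make the returned structure concrete, here is a hypothetical result for one variable saved whole (empty slice spec) and one saved as two row slices; the checkpoint keys and shapes are illustrative only:

```
import tensorflow as tf

restored_tensor_dict = {
    # checkpoint_key -> slice_spec -> tensor; "" means the whole tensor.
    "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE": {"": tf.zeros([4, 4])},
    "embed/table/.ATTRIBUTES/VARIABLE_VALUE": {
        # "10 20 0,5:-" = full shape [10, 20], rows 0..4, all columns.
        "10 20 0,5:-": tf.zeros([5, 20]),
        "10 20 5,5:-": tf.zeros([5, 20]),
    },
}
```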
Example #6
  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when restoring on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          # Load values from checkpoint
          restored_tensor_dict = saver.restore(file_prefix, options)

          # Map restored tensors to the corresponding restore_fn, and check
          # whether all of its inputs have been loaded. Call `restore_fn` once
          # that is the case.
          for checkpoint_key, slice_and_tensor in restored_tensor_dict.items():
            for slice_spec, tensor in slice_and_tensor.items():
              restore_fn = self._keys_to_restore_fn[(checkpoint_key,
                                                     slice_spec)]
              (restore_fn_inputs
               .setdefault(restore_fn, {})
               .setdefault(checkpoint_key, {})[slice_spec]) = tensor
              restore_fn_input_count[restore_fn] -= 1

              if restore_fn_input_count[restore_fn] == 0:
                ret = restore_fn(restore_fn_inputs[restore_fn])
                if isinstance(ret, dict):
                  restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, restore_fn) in self._registered_savers.items():
        restore_fn(file_prefix)
      return restore_ops

    restore_device = options.experimental_io_device or "cpu:0"

    # Since this will cause a function re-trace on each restore, limit it to
    # cases where it is needed: eager execution with multiple tasks/single-
    # device savers. Note that the retrace is needed to ensure we pick up
    # the latest values of options like experimental_io_device.
    if context.executing_eagerly() and (len(self._single_device_savers) > 1 or
                                        options.experimental_io_device):
      @def_function.function(jit_compile=False)
      def tf_function_restore():
        restore_fn()
        return {}

      with ops.device(restore_device):
        restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops
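The per-`restore_fn` countdown above is a general "fire once every input has arrived" pattern. Stripped of the checkpoint details, it looks roughly like this (all names hypothetical):

```
def deliver(key, value, key_to_fn, pending, inputs):
  """Route one loaded value; call its fn once all inputs are present."""
  fn = key_to_fn[key]
  inputs.setdefault(fn, {})[key] = value
  pending[fn] -= 1
  if pending[fn] == 0:
    fn(inputs[fn])

results = []
merge = lambda d: results.append(dict(d))
pending, inputs = {merge: 2}, {}
key_to_fn = {"a": merge, "b": merge}
deliver("a", 1, key_to_fn, pending, inputs)
deliver("b", 2, key_to_fn, pending, inputs)  # merge fires here
assert results == [{"a": 1, "b": 2}]
```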
Example #7
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location; using one there can leave files unavailable
    #   during the copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit it to
    # cases where it is needed: eager execution with multiple tasks/single-
    # device savers. Note that the retrace is needed to ensure we pick up
    # the latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()
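The shard-then-merge flow can be exercised directly with the raw ops. A sketch assuming a writable /tmp directory; the shard prefixes follow the "<prefix>_temp/part-NNNNN-of-NNNNN" shape described in the comments above:

```
import tensorflow as tf

base = "/tmp/demo_merge/myckpt"
tf.io.gfile.makedirs("/tmp/demo_merge")
num_shards = 2
shard_prefixes = []
for shard in range(num_shards):
  prefix = "%s_temp/part-%05d-of-%05d" % (base, shard, num_shards)
  tf.raw_ops.SaveV2(prefix=prefix,
                    tensor_names=["v%d" % shard],
                    shape_and_slices=[""],
                    tensors=[tf.constant(float(shard))])
  shard_prefixes.append(prefix)
# Merge shard metadata into "<base>.index" and delete the _temp directory.
tf.raw_ops.MergeV2Checkpoints(checkpoint_prefixes=shard_prefixes,
                              destination_prefix=base,
                              delete_old_dirs=True)
print(tf.io.gfile.glob(base + "*"))
```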
Example #8
    def restore(self, save_path, options=None):
        """Restore a training checkpoint with host mesh placement."""
        options = options or checkpoint_options.CheckpointOptions()
        if save_path is None:
            return util.InitializationOnlyStatus(self._graph_view, ops.uid())
        reader = py_checkpoint_reader.NewCheckpointReader(save_path)
        graph_building = not context.executing_eagerly()
        if graph_building:
            dtype_map = None
        else:
            dtype_map = reader.get_variable_to_dtype_map()
        try:
            object_graph_string = reader.get_tensor(
                base.OBJECT_GRAPH_PROTO_KEY)
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try the
            # name-based compatibility mode.
            restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
                save_path=save_path,
                dtype_map=dtype_map)
            if not graph_building:
                for existing_trackable in self._graph_view.list_objects():
                    # pylint: disable=protected-access
                    existing_trackable._maybe_initialize_trackable()
                    existing_trackable._name_based_restores.add(
                        restore_coordinator)
                    existing_trackable._name_based_attribute_restore(
                        restore_coordinator)
                    # pylint: enable=protected-access
            return util.NameBasedSaverStatus(restore_coordinator,
                                             graph_view=self._graph_view)

        if graph_building:
            if self._file_prefix_placeholder is None:
                # DTensor change: provide a hint for mesh broadcasting to put the input
                # onto the host mesh.
                self._file_prefix_placeholder = api.pack(
                    [constant_op.constant("model")] *
                    self._mesh.num_local_devices(),
                    layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            # DTensor change: provide a hint for mesh broadcasting to put the input
            # onto the host mesh.
            file_prefix_tensor = api.pack([constant_op.constant(save_path)] *
                                          self._mesh.num_local_devices(),
                                          layout.Layout.replicated(
                                              self._mesh.host_mesh(), rank=0))
            file_prefix_feed_dict = None
        object_graph_proto = (
            trackable_object_graph_pb2.TrackableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        # DTensor Change: Hook the proper DSaver in restore.
        checkpoint = _DCheckpointRestoreCoordinator(
            mesh=self._mesh,
            object_graph_proto=object_graph_proto,
            save_path=save_path,
            save_path_tensor=file_prefix_tensor,
            reader=reader,
            restore_op_cache=self._restore_op_cache,
            graph_view=self._graph_view,
            options=options,
            saveables_cache=self._saveables_cache)
        base.CheckpointPosition(checkpoint=checkpoint,
                                proto_id=0).restore(self._graph_view.root)

        # Attached dependencies are not attached to the root, so should be restored
        # separately.
        if self._graph_view.attached_dependencies:
            for ref in self._graph_view.attached_dependencies:
                if ref.name == "root":
                    # Root dependency is automatically added to attached dependencies --
                    # this can be ignored since it maps back to the root object.
                    continue
                proto_id = None
                # Find proto ID of attached dependency (if it is in the proto).
                for proto_ref in object_graph_proto.nodes[0].children:
                    if proto_ref.local_name == ref.name:
                        proto_id = proto_ref.node_id
                        break

                if proto_id in checkpoint.object_by_proto_id:
                    # Object has already been restored. This can happen when there's an
                    # indirect connection from the attached object to the root.
                    continue

                base.CheckpointPosition(checkpoint=checkpoint,
                                        proto_id=proto_id).restore(ref.ref)

        load_status = util.CheckpointLoadStatus(
            checkpoint,
            graph_view=self._graph_view,
            feed_dict=file_prefix_feed_dict)
        return load_status
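For context on the DTensor-specific lines: packing the prefix with a replicated host-mesh layout is what hints TensorFlow to place the string input on host (CPU) devices. A rough sketch with the public `tf.experimental.dtensor` API, assuming a mesh object already exists (mesh construction omitted):

```
import tensorflow as tf
from tensorflow.experimental import dtensor

def pack_file_prefix(save_path, mesh):
  # Replicate the scalar prefix across the host mesh so restore ops can
  # read it from host devices.
  return dtensor.pack(
      [tf.constant(save_path)] * mesh.num_local_devices(),
      dtensor.Layout.replicated(mesh.host_mesh(), rank=0))
```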
Example #9
def load_partial(export_dir, filters, tags=None, options=None):
    """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the name of the attributes connecting the
  objects.

  *Example 1*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True
  ```

  *Example 2*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> new_variable.numpy()
  5.0
  ```

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  ```
  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
    options = options or load_options.LoadOptions()
    if tags is not None and not isinstance(tags, set):
        # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
        # sequences for nest.flatten, so we put those through as-is.
        tags = nest.flatten(tags)
    saved_model_proto, debug_info = (
        loader_impl.parse_saved_model_with_debug_info(export_dir))

    if (len(saved_model_proto.meta_graphs) == 1
            and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
        metrics.IncrementReadApi(_LOAD_V2_LABEL)
        meta_graph_def = saved_model_proto.meta_graphs[0]
        # The tensor_content field contains raw bytes in little-endian
        # format, which causes problems when loaded on big-endian systems
        # that require a byte swap.
        if sys.byteorder == "big":
            saved_model_utils.swap_function_tensor_content(
                meta_graph_def, "little", "big")
        if (tags is not None
                and set(tags) != set(meta_graph_def.meta_info_def.tags)):
            raise ValueError(
                f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
                f"{export_dir} has one MetaGraph with tags "
                f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
                "pass 'None', or pass matching tags.")
        object_graph_proto = meta_graph_def.object_graph_def

        ckpt_options = checkpoint_options.CheckpointOptions(
            experimental_io_device=options.experimental_io_device)
        with ops.init_scope():
            try:
                loader = Loader(object_graph_proto, saved_model_proto,
                                export_dir, ckpt_options, options, filters)
            except errors.NotFoundError as err:
                raise FileNotFoundError(
                    str(err) +
                    "\n You may be trying to load on a different device "
                    "from the computational device. Consider setting the "
                    "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
                    "to the io_device such as '/job:localhost'.")
            root = loader.get(0)
            root.graph_debug_info = loader.adjust_debug_info_func_names(
                debug_info)
        root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        metrics.IncrementRead(write_version="2")
    else:
        if filters:
            raise ValueError(
                "SavedModels saved from Tensorflow 1.x or Estimator (any"
                " version) cannot be loaded with node filters.")
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
            root.graph_debug_info = debug_info

    if filters:
        return {node_id: loader.get(node_id) for node_id in filters}
    else:
        return {"root": root}