Esempio n. 1
0
    def _restore_checkpoint(self):
        """Load state from checkpoint into the deserialized objects.

        Restores the SavedModel's variables checkpoint into the object graph
        rooted at ``self.get(0)`` and calls ``assert_existing_objects_matched()``
        on the resulting load status. When graph building, the restore ops
        produced for each resource variable are installed as that variable's
        initializer, so that common initialization flows (e.g. the one used by
        ManagedSession) restore it without further user action.

        Raises:
          NotImplementedError: In graph mode, if restore ops are produced for
            an object that is not a resource variable.
        """
        # Local imports keep this method self-contained; they are harmless if
        # the module already imports these names at the top.
        from tensorflow.python.eager import context  # pylint: disable=g-import-not-at-top
        from tensorflow.python.ops import control_flow_ops  # pylint: disable=g-import-not-at-top

        variables_path = saved_model_utils.get_variables_path(self._export_dir)
        # TODO(andresp): Clean use of private methods of TrackableSaver.
        # pylint: disable=protected-access
        saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
        with ops.device("CPU"):
            saver._file_prefix_placeholder = constant_op.constant(
                variables_path)
        load_status = saver.restore(variables_path)
        load_status.assert_existing_objects_matched()
        checkpoint = load_status._checkpoint

        if context.executing_eagerly():
            # The `restore` call above has already restored the state of all
            # trackables. Calling `position.restore_ops()` here would only
            # re-run the restore, so there is nothing left to wire up.
            return

        # In graph mode, `position.restore_ops()` returns the ops that must run
        # to restore the object at that position. Wire them into the objects'
        # initializers so the objects get initialized properly when using
        # common practices (e.g. ManagedSession) without further user action.
        # Iterate over a snapshot because restoring may touch the mapping.
        for object_id, obj in dict(checkpoint.object_by_proto_id).items():
            position = base.CheckpointPosition(checkpoint=checkpoint,
                                               proto_id=object_id)
            restore_ops = position.restore_ops()
            if not restore_ops:
                continue
            if not resource_variable_ops.is_resource_variable(obj):
                raise NotImplementedError(
                    ("Missing functionality to restore state of object "
                     "%r from the checkpoint." % obj))
            # `_initializer_op` must be a single op rather than a list of ops,
            # so unwrap a lone op or group several into one.
            if len(restore_ops) == 1:
                obj._initializer_op = restore_ops[0]
            else:
                obj._initializer_op = control_flow_ops.group(*restore_ops)
Esempio n. 2
0
def _build_meta_graph(obj,
                      export_dir,
                      signatures,
                      options,
                      meta_graph_def=None):
    """Creates a MetaGraph containing the resources and functions of an object.

    Args:
      obj: The `Trackable` object to export.
      export_dir: Directory of the SavedModel being written. Here it is only
        used to locate the debug-info file when `options.save_debug_info` is
        set.
      signatures: Signatures to export, or None to search `obj` for a function
        to export.
      options: Save options; this function reads `namespace_whitelist`,
        `function_aliases` and `save_debug_info` from it.
      meta_graph_def: Optional `MetaGraphDef` proto to fill in. A new one is
        created when None.

    Returns:
      A tuple `(meta_graph_def, exported_graph, object_saver, asset_info)`.

    Raises:
      AssertionError: If called inside a traced `tf.function`.
      ValueError: If `obj` is not a `Trackable`.
    """
    if ops.inside_function():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
    if not isinstance(obj, base.Trackable):
        raise ValueError(
            "Expected a Trackable object for export, got {}.".format(obj))
    meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()

    checkpoint_graph_view = _AugmentedGraphView(obj)
    if signatures is None:
        signatures = signature_serialization.find_function_to_export(
            checkpoint_graph_view)

    signatures, wrapped_functions = (
        signature_serialization.canonicalize_signatures(signatures))
    signature_serialization.validate_saveable_view(checkpoint_graph_view)
    signature_map = signature_serialization.create_signature_map(signatures)
    checkpoint_graph_view.add_object(
        parent_node=checkpoint_graph_view.root,
        name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
        subgraph_root=signature_map)

    # Use _SaveableView to provide a frozen listing of properties and functions.
    # Note we run this twice since, while constructing the view the first time
    # there can be side effects of creating variables.
    _ = _SaveableView(checkpoint_graph_view)
    saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
    object_saver = util.TrackableSaver(checkpoint_graph_view)
    asset_info, exported_graph = _fill_meta_graph_def(
        meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
    if options.function_aliases:
        function_aliases = meta_graph_def.meta_info_def.function_aliases
        for alias, func in options.function_aliases.items():
            # The stateful/stateless wrappers may be None (presumably for a
            # tf.function that has not been traced yet — TODO confirm); skip
            # them instead of failing with AttributeError.
            for wrapped_fn in (func._stateful_fn, func._stateless_fn):  # pylint: disable=protected-access
                if wrapped_fn is None:
                    continue
                for fdef in wrapped_fn._function_cache.all_values():  # pylint: disable=protected-access
                    function_aliases[fdef.name] = alias

    object_graph_proto = _serialize_object_graph(saveable_view,
                                                 asset_info.asset_index)
    meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

    # Save debug info, if requested.
    if options.save_debug_info:
        graph_debug_info = _export_debug_info(exported_graph)
        file_io.atomic_write_string_to_file(
            os.path.join(utils_impl.get_or_create_debug_dir(export_dir),
                         constants.DEBUG_INFO_FILENAME_PB),
            graph_debug_info.SerializeToString(deterministic=True))

    return meta_graph_def, exported_graph, object_saver, asset_info
Esempio n. 3
0
 def testLoadFromNameBasedSaver(self):
     """Save a name-based checkpoint, load it using the object-based API.

     Sentinel values are written over the model's variables before each
     restore, and `_check_sentinels` is used to verify that the restore
     replaced them.
     """
     with test_util.device(use_gpu=True):
         save_path = self._write_name_based_checkpoint()
         root = self._initialized_model()
         self._set_sentinels(root)
         # Sentinels differ from the initialized values, so the check fails.
         with self.assertRaises(AssertionError):
             self._check_sentinels(root)
         object_saver = util.TrackableSaver(
             graph_view.ObjectGraphView(root))
         self._set_sentinels(root)
         status = object_saver.restore(save_path)
         if context.executing_eagerly():
             # Eager restores run immediately and are fully consumed.
             self._check_sentinels(root)
             status.assert_consumed()
         else:
             # When graph building, we haven't read any keys, so we don't know
             # whether the restore will be complete.
             with self.assertRaisesRegex(AssertionError, "not restored"):
                 status.assert_consumed()
         status.run_restore_ops()
         self._check_sentinels(root)
         self._set_sentinels(root)
         status = object_saver.restore(save_path)
         status.initialize_or_restore()
         self._check_sentinels(root)
Esempio n. 4
0
    def _restore_checkpoint(self):
        """Restore checkpointed state into the freshly deserialized objects.

        Reads the SavedModel's variables checkpoint into the object graph rooted
        at ``self.get(0)``. If ``self._save_options.allow_partial_checkpoint``
        is set, a partial match is tolerated (``expect_partial()`` followed by
        ``assert_nontrivial_match()``); otherwise
        ``assert_existing_objects_matched()`` is required. In graph mode the
        per-object restore ops are wired into initializers: resource variables
        get them as ``_initializer_op``, and lookup tables have them added to
        the ``TABLE_INITIALIZERS`` collection.

        Raises:
          NotImplementedError: In graph mode, if an object uses a registered
            checkpoint saver, or if restore ops are produced for an object that
            is neither a resource variable nor a ``LookupInterface``.
        """
        ckpt_path = saved_model_utils.get_variables_path(self._export_dir)
        # TODO(b/205010730): Clean use of private methods of TrackableSaver.
        # pylint: disable=protected-access
        trackable_saver = util.TrackableSaver(
            graph_view.ObjectGraphView(self.get(0)))
        with ops.device("CPU"):
            trackable_saver._file_prefix_placeholder = constant_op.constant(
                ckpt_path)

        partial_allowed = self._save_options.allow_partial_checkpoint
        status = trackable_saver.restore(ckpt_path, self._checkpoint_options)
        if partial_allowed:
            status = status.expect_partial()
            status.assert_nontrivial_match()
        else:
            status.assert_existing_objects_matched()
        ckpt = status._checkpoint

        if context.executing_eagerly():
            # Eager `restore` has already restored every trackable; calling
            # `position.restore_ops()` would just run the restore again.
            return

        # Graph mode: `restore_ops()` hands back a cached list of ops that must
        # run to restore each object. Hook them into the objects' initializers
        # so common practices (e.g. ManagedSession) restore state without any
        # further user action. Iterate over a snapshot of the mapping.
        for proto_id, obj in dict(ckpt.object_by_proto_id).items():
            position = base.CheckpointPosition(checkpoint=ckpt,
                                               proto_id=proto_id)
            registered_saver = position.get_registered_saver_name()
            if registered_saver:
                raise NotImplementedError(
                    "Loading a SavedModel that uses registered checkpoint saver is "
                    f"not supported in graph mode. The loaded object {obj} uses the "
                    f"saver registered with the name {registered_saver}.")

            pending_ops = position.restore_ops()
            if not pending_ops:
                continue

            if resource_variable_ops.is_resource_variable(obj):
                # A single op is used directly; several are grouped into one.
                obj._initializer_op = (
                    pending_ops[0] if len(pending_ops) == 1
                    else control_flow_ops.group(*pending_ops))
            elif isinstance(obj, lookup_ops.LookupInterface):
                # Only reachable in graph mode (eager returned above).
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      pending_ops)
            else:
                raise NotImplementedError(
                    f"Unable to restore state of object {obj} from the checkpoint."
                )
Esempio n. 5
0
 def testLoadFromNameBasedSaver(self):
     """Restore a name-based checkpoint through the object-based saver API.

     Sentinel values are written over the model's variables before each
     restore and `_check_sentinels` verifies that the restore replaced them.
     The test also checks that a variable missing from the name-based
     checkpoint does not break `restore`, but does make
     `assert_existing_objects_matched` fail.
     """
     with test_util.device(use_gpu=True), self.test_session():
         checkpoint_path = self._write_name_based_checkpoint()
         model = self._initialized_model()
         self._set_sentinels(model)
         # Sentinels overwrote the initialized values, so the check must fail.
         with self.assertRaises(AssertionError):
             self._check_sentinels(model)
         saver = trackable_utils.TrackableSaver(
             graph_view.ObjectGraphView(model))
         self._set_sentinels(model)
         status = saver.restore(checkpoint_path)
         if context.executing_eagerly():
             # Eager restores run immediately, so every check passes at once.
             self._check_sentinels(model)
             status.assert_consumed()
             status.assert_existing_objects_matched()
             status.assert_nontrivial_match()
         else:
             # While graph building no keys have been read yet, so none of the
             # completeness checks can pass.
             for completeness_check in (status.assert_consumed,
                                        status.assert_existing_objects_matched,
                                        status.assert_nontrivial_match):
                 with self.assertRaisesRegex(AssertionError,
                                             "not restored"):
                     completeness_check()
         status.run_restore_ops()
         self._check_sentinels(model)

         self._set_sentinels(model)
         status = saver.restore(checkpoint_path)
         status.initialize_or_restore()
         status.assert_nontrivial_match()
         self._check_sentinels(model)

         # Keys missing from the name-based checkpoint must not raise during
         # restore, but existing-object matching should then fail.
         model.not_in_name_checkpoint = resource_variable_ops.ResourceVariable(
             [1.])
         status = saver.restore(checkpoint_path)
         with self.assertRaises(AssertionError):
             status.assert_existing_objects_matched()
Esempio n. 6
0
    def _restore_checkpoint(self):
        """Load state from checkpoint into the deserialized objects.

        Restores the SavedModel's variables checkpoint into the object graph
        rooted at ``self.get(0)``. If ``self._expect_partial_checkpoint`` is
        set, the load status is marked with ``expect_partial()``. In either case
        ``assert_existing_objects_matched()`` is called on it. When graph
        building, the per-object restore ops are wired into initializers:
        resource variables get them as ``_initializer_op``, and lookup tables
        have them added to the ``TABLE_INITIALIZERS`` collection.

        Raises:
          NotImplementedError: In graph mode, if restore ops are produced for an
            object that is neither a resource variable nor a
            ``LookupInterface``.
        """
        # Local import keeps this method self-contained; harmless if the module
        # already imports `context` at the top.
        from tensorflow.python.eager import context  # pylint: disable=g-import-not-at-top

        variables_path = saved_model_utils.get_variables_path(self._export_dir)
        # TODO(andresp): Clean use of private methods of TrackableSaver.
        # pylint: disable=protected-access
        saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
        with ops.device("CPU"):
            saver._file_prefix_placeholder = constant_op.constant(
                variables_path)
        if self._expect_partial_checkpoint:
            load_status = saver.restore(
                variables_path, self._checkpoint_options).expect_partial()
        else:
            load_status = saver.restore(variables_path,
                                        self._checkpoint_options)
        load_status.assert_existing_objects_matched()
        checkpoint = load_status._checkpoint

        if context.executing_eagerly():
            # The `restore` call above has already restored the state of all
            # trackables. Calling `position.restore_ops()` here would only
            # re-run the restore, so there is nothing left to do.
            return

        # In graph mode, `position.restore_ops()` returns the list of ops that
        # must run to restore the object at that position. We wire them into the
        # initializers of the objects so that they get initialized properly when
        # using common practices (e.g. the ones used by ManagedSession) without
        # further user action. Iterate over a snapshot of the mapping.
        for object_id, obj in dict(checkpoint.object_by_proto_id).items():
            position = base.CheckpointPosition(checkpoint=checkpoint,
                                               proto_id=object_id)
            restore_ops = position.restore_ops()
            if not restore_ops:
                continue
            if resource_variable_ops.is_resource_variable(obj):
                # `_initializer_op` must be one op: unwrap or group.
                if len(restore_ops) == 1:
                    obj._initializer_op = restore_ops[0]
                else:
                    obj._initializer_op = control_flow_ops.group(
                        *restore_ops)
            elif isinstance(obj, lookup_ops.LookupInterface):
                # Only reachable in graph mode (eager returned above).
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      restore_ops)
            else:
                raise NotImplementedError(
                    ("Missing functionality to restore state of object "
                     "%r from the checkpoint." % obj))
Esempio n. 7
0
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither is specified, the model's forward pass is exported.

  ```python
  x = tf.keras.layers.Input((4,), name="x")
  y = tf.keras.layers.Dense(5, name="out")(x)
  model = tf.keras.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example that
  exporting a model which runs on a GPU and serving it on a CPU will generally
  work, with some exceptions. `tf.device` annotations inside the body of the
  function will be hard-coded in the exported model; this type of annotation is
  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with
  device-specific layouts, may cause issues. Currently a `DistributionStrategy`
  is another exception: active distribution strategies will cause device
  placements to be hard-coded in a function. Exporting a single-device
  computation and importing under a `DistributionStrategy` is not currently
  supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving. Defaults to `SaveOptions()` when not given.

  Raises:
    AssertionError: If called from within the body of a traced `tf.function`.
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  # Fall back to default save options when the caller passes none.
  options = options or save_options.SaveOptions()

  checkpoint_graph_view = _AugmentedGraphView(obj)
  # Without explicit signatures, search `obj` for a function to export.
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  # Add the signature map under the root object, using the reserved
  # signatures attribute name.
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view)
  saveable_view = _SaveableView(checkpoint_graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)
  # So far we've just been generating protocol buffers with no I/O. Now we write
  # the checkpoint, copy assets into the assets directory, and write out the
  # SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  # Destination of the SavedModel proto; it is written last (see below).
  path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  object_graph_proto = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

  # Save debug info, if requested.
  if options.save_debug_info:
    graph_debug_info = _export_debug_info(exported_graph)
    file_io.atomic_write_string_to_file(
        os.path.join(
            utils_impl.get_or_create_debug_dir(export_dir),
            constants.DEBUG_INFO_FILENAME_PB),
        graph_debug_info.SerializeToString(deterministic=True))

  # Note that this needs to be the last file operation when saving the
  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
  # indication that the SavedModel is completely written.
  file_io.atomic_write_string_to_file(
      path, saved_model.SerializeToString(deterministic=True))

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)