Ejemplo n.º 1
0
  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Import the tagged MetaGraphDef into `sess.graph` and restore variables.

    Besides loading, this records read metrics: the API counter is tagged with
    write version "2" when the SavedModel holds exactly one MetaGraph carrying
    an `object_graph_def`, and with "1" otherwise.

    Args:
      sess: tf.compat.v1.Session whose graph receives the MetaGraphDef and in
        which variable values are restored.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string`. When given, it is prepended (followed by
        '/') to every loaded tensor name. The scope affects the tensors loaded
        into the passed session only; it is *not* written through to the
        static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.
    """
    saved_model_proto = parse_saved_model(self._export_dir)
    meta_graphs = saved_model_proto.meta_graphs
    has_object_graph = (
        len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"))
    write_version = "2" if has_object_graph else "1"
    metrics.IncrementReadApi(_LOADER_LABEL, write_version=write_version)

    with sess.graph.as_default():
      # Only the saver is needed here; the second element returned by
      # `load_graph` is discarded.
      saver, _unused = self.load_graph(
          sess.graph, tags, import_scope, **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)

    # Look the proto up before bumping the read counter, so an exception raised
    # by the tag lookup leaves the counter untouched.
    loaded_meta_graph = self.get_meta_graph_def_from_tags(tags)
    metrics.IncrementRead()
    return loaded_meta_graph
Ejemplo n.º 2
0
def load(export_dir, tags=None, options=None):
    """Load a SavedModel from `export_dir` and return its root object.

    The signatures stored with the SavedModel are exposed as callable
    functions:

    ```python
    imported = tf.saved_model.load(path)
    f = imported.signatures["serving_default"]
    print(f(x=tf.constant([[1.]])))
    ```

    When the SavedModel was written by `tf.saved_model.save`, the trackable
    objects and functions of the exported object are additionally assigned to
    attributes:

    ```python
    exported = tf.train.Checkpoint(v=tf.Variable(3.))
    exported.f = tf.function(
        lambda x: exported.v * x,
        input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    tf.saved_model.save(exported, path)
    imported = tf.saved_model.load(path)
    assert 3. == imported.v.numpy()
    assert 6. == imported.f(x=tf.constant(2.)).numpy()
    ```

    _Loading Keras models_

    Keras models are trackable, so they can be saved to SavedModel. What
    `tf.saved_model.load` returns is not a Keras object, however (it has no
    `.fit`, `.predict`, etc. methods). A few attributes and functions remain
    available: `.variables`, `.trainable_variables` and `.__call__`.

    ```python
    model = tf.keras.Model(...)
    tf.saved_model.save(model, path)
    imported = tf.saved_model.load(path)
    outputs = imported(inputs)
    ```

    Use `tf.keras.models.load_model` to restore the Keras model.

    _Importing SavedModels from TensorFlow 1.x_

    SavedModels produced by `tf.estimator.Estimator` or by the 1.x SavedModel
    APIs contain a flat graph rather than `tf.function` objects. Such
    SavedModels are loaded with the following attributes:

    * `.signatures`: A dictionary mapping signature names to functions.
    * `.prune(feeds, fetches) `: A method for extracting functions for new
      subgraphs. This is equivalent to importing the SavedModel and naming
      feeds and fetches in a Session from TensorFlow 1.x.

      ```python
      imported = tf.saved_model.load(path_to_v1_saved_model)
      pruned = imported.prune("x:0", "out:0")
      pruned(tf.ones([]))
      ```

      See `tf.compat.v1.wrap_function` for details.
    * `.variables`: A list of imported variables.
    * `.graph`: The whole imported graph.
    * `.restore(save_path)`: A function that restores variables from a
      checkpoint saved from `tf.compat.v1.Saver`.

    _Consuming SavedModels asynchronously_

    When a SavedModel is consumed asynchronously (the producer is a separate
    process), its directory appears before all files have been written, and
    `tf.saved_model.load` fails when pointed at an incomplete SavedModel.
    Instead of checking for the directory, check for
    "saved_model_dir/saved_model.pb". This file is written atomically as the
    last `tf.saved_model.save` file operation.

    Args:
      export_dir: The SavedModel directory to load from.
      tags: A tag or sequence of tags identifying the MetaGraph to load.
        Optional if the SavedModel contains a single MetaGraph, as for those
        exported from `tf.saved_model.save`.
      options: `tf.saved_model.LoadOptions` object that specifies options for
        loading.

    Returns:
      A trackable object with a `signatures` attribute mapping from signature
      keys to functions. If the SavedModel was exported by
      `tf.saved_model.save`, it also points to the trackable objects,
      functions and debug info with which it was saved.

    Raises:
      ValueError: If `tags` don't match a MetaGraph in the SavedModel.
    """
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    loaded = load_internal(export_dir, tags, options)
    root = loaded["root"]
    # The read counter is bumped only after the load has succeeded.
    metrics.IncrementRead()
    return root
Ejemplo n.º 3
0
 def test_increment_read(self):
   """Checks that the read and per-API read counters each advance by one."""
   api_label = "bar"

   # The global read counter is expected to start from zero.
   reads_before = metrics.GetRead()
   self.assertEqual(reads_before, 0)

   metrics.IncrementReadApi(api_label)
   api_reads = metrics.GetReadApi(api_label)
   self.assertEqual(api_reads, 1)

   metrics.IncrementRead()
   reads_after = metrics.GetRead()
   self.assertEqual(reads_after, 1)
Ejemplo n.º 4
0
def load(export_dir, tags):
    """Load a v1-style SavedModel as an object.

    Records one read-API metric under `_LOAD_V1_V2_LABEL`, then hands the work
    to an `_EagerSavedModelLoader` created for `export_dir`.

    Args:
      export_dir: Passed to the `_EagerSavedModelLoader` constructor.
      tags: Forwarded to the loader's `load` method as `tags`.

    Returns:
      Whatever `_EagerSavedModelLoader.load` returns for `tags`.
    """
    metrics.IncrementReadApi(_LOAD_V1_V2_LABEL)
    return _EagerSavedModelLoader(export_dir).load(tags=tags)
Ejemplo n.º 5
0
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation shared by the SavedModel loading entry points.

  A SavedModel holding exactly one MetaGraph that carries an
  `object_graph_def` is loaded through `loader_cls`; anything else is handed
  to `load_v1_in_v2.load`.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: `None`, a set of tags, or any structure `nest.flatten` accepts (e.g.
      a single tag or a list of tags). Non-set values are flattened to a list.
    options: Load options; a default `load_options.LoadOptions()` is used when
      this is falsy. Its `experimental_io_device` is forwarded to the
      checkpoint options.
    loader_cls: Callable invoked as `loader_cls(object_graph_proto,
      saved_model_proto, export_dir, ckpt_options, options, filters)` to build
      the loader for the object-graph path.
    filters: Optional collection of node ids. When truthy, the returned dict
      is keyed by these ids rather than by "root". Only supported on the
      object-graph path.

  Returns:
    `{"root": root}` when `filters` is falsy, otherwise a dict mapping each
    node id in `filters` to `loader.get(node_id)`.

  Raises:
    ValueError: If `tags` is given and differs from the tags of the single
      object-graph MetaGraph, or if `filters` is given for a SavedModel that
      does not take the object-graph path.
    FileNotFoundError: If building the loader raises `errors.NotFoundError`;
      the original error is chained as the cause.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL, write_version="2")
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems, so a byteswap
    # is required there.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, options, filters)
      except errors.NotFoundError as err:
        # Chain the original error explicitly so the traceback shows it as the
        # direct cause rather than as an unrelated error during handling.
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider using setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'."
        ) from err
      root = loader.get(0)
      # Debug info is attached only when the loader is a `Loader` instance.
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow V1 or Estimator (any "
                       "version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  # `loader` is always bound here when `filters` is truthy: the other branch
  # raises for truthy filters before reaching this point.
  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}