def load(self, sess, tags, import_scope=None, **saver_kwargs):
  """Import the tagged MetaGraphDef into `sess.graph` and restore variables.

  Args:
    sess: tf.compat.v1.Session to restore variable values.
    tags: a set of string tags identifying a MetaGraphDef.
    import_scope: Optional `string` -- if specified, prepend this string
      followed by '/' to all loaded tensor names. This scope is applied to
      tensor instances loaded into the passed session, but it is *not*
      written through to the static `MetaGraphDef` protocol buffer that is
      returned.
    **saver_kwargs: keyword arguments forwarded to `self.load_graph`
      (documented upstream as being passed to tf.train.import_meta_graph).

  Returns:
    `MetaGraphDef` proto of the graph that was loaded.
  """
  saved_model_proto = parse_saved_model(self._export_dir)
  metrics.IncrementReadApi(_LOADER_LABEL)

  target_graph = sess.graph
  with target_graph.as_default():
    saver, _ = self.load_graph(
        target_graph, tags, import_scope, **saver_kwargs)
    self.restore_variables(sess, saver, import_scope)
    self.run_init_ops(sess, tags, import_scope)
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # A single MetaGraph carrying an object graph is counted as write
  # version "2"; every other layout is counted as version "1".
  meta_graphs = saved_model_proto.meta_graphs
  is_object_based = (
      len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"))
  metrics.IncrementRead(write_version="2" if is_object_based else "1")

  return meta_graph_def
def load(export_dir, tags):
  """Load a v1-style SavedModel as an object.

  Args:
    export_dir: Directory handed to `_EagerSavedModelLoader`.
    tags: Tags forwarded to the loader's `load` call.

  Returns:
    Whatever `_EagerSavedModelLoader.load` returns for `tags`.
  """
  metrics.IncrementReadApi(_LOAD_V1_V2_LABEL)
  loaded = _EagerSavedModelLoader(export_dir).load(tags=tags)
  # The read counter is bumped only after the load call has returned.
  metrics.IncrementRead(write_version="1")
  return loaded
def test_SM_increment_read(self):
  """Check that the read and read-API counters each go up by one."""
  api_label = "bar"
  version = "2"

  # The version counter is expected to start from zero.
  reads_before = metrics.GetRead(write_version=version)
  self.assertEqual(reads_before, 0)

  metrics.IncrementReadApi(api_label)
  api_reads = metrics.GetReadApi(api_label)
  self.assertEqual(api_reads, 1)

  metrics.IncrementRead(write_version=version)
  reads_after = metrics.GetRead(write_version=version)
  self.assertEqual(reads_after, 1)
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Load a SavedModel from `export_dir` and return the loaded objects.

  A SavedModel with exactly one MetaGraph that carries an `object_graph_def`
  is loaded through `loader_cls`. Any other SavedModel is handed to
  `load_v1_in_v2.load`.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: Optional tag, sequence of tags, or set of tags. Anything that is
      not a set is flattened with `nest.flatten`. On the object-graph path,
      tags that are given must equal the tags stored in the MetaGraph.
    options: Optional load options; a default `load_options.LoadOptions()` is
      used when this is falsy.
    loader_cls: Class instantiated to perform the object-graph load.
    filters: Optional node filters forwarded to `loader_cls`. When truthy,
      the result is keyed by the entries of `filters` instead of by "root".

  Returns:
    `{node_id: loader.get(node_id)}` for each entry of `filters` when
    `filters` is truthy, otherwise `{"root": root}`.

  Raises:
    ValueError: If `tags` does not match the single MetaGraph's tags, or if
      `filters` is given for a SavedModel without a single object graph.
    FileNotFoundError: If constructing the loader raises
      `errors.NotFoundError`.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # tensor_content field contains raw bytes in little endian format
    # which causes problems when loaded on big-endian systems
    # requiring byteswap
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(
          meta_graph_def, "little", "big")
    if (tags is not None and
        set(tags) != set(meta_graph_def.meta_info_def.tags)):
      # Each fragment that interpolates a value needs its own `f` prefix:
      # adjacent literals are concatenated, but only f-prefixed ones are
      # formatted. Without it the message shows the text "{tags}" verbatim.
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.")
      root = loader.get(0)
      # Only `Loader` instances get their debug info adjusted here.
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(
            debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError(
          "SavedModels saved from Tensorflow 1.x or Estimator (any"
          " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  # `loader` is always bound here when `filters` is truthy, because the
  # fallback branch above raises before reaching this point.
  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that lets
  you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved by
  the Estimator API.

  In Tensorflow V2, SavedModel stores the **object graph** of the saved
  object. The graph contains nodes (`tf.Module`, `tf.Variable`,
  `tf.function`, Keras layers, etc.) and edges that are the name of the
  attributes connecting the objects.

  *Example 1*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True
  ```

  *Example 2*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.
  >>> new_variable.numpy()
  5.
  ```

  **Loading under different distribution strategies**

  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  ```
  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the
      child attribute names to reach that node in the form:
      `root.{attribute_name}`. The loader will load all of the specified
      nodes and their recursive descendants. When this option is defined,
      the loader will return a dictionary mapping the node paths to the
      loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load.
      Optional if the SavedModel contains a single MetaGraph, as for those
      exported from `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.

  Raises:
    ValueError: If `tags` does not match the tags of the single MetaGraph,
      or if `filters` is given for a SavedModel that does not hold exactly
      one MetaGraph with an object graph.
    FileNotFoundError: If constructing the loader raises
      `errors.NotFoundError`.
  """
  if not options:
    options = load_options.LoadOptions()
  # A bare tag or a list of tags is normalised through nest.flatten; sets
  # are not treated as sequences by nest.flatten, so they pass unchanged.
  if not (tags is None or isinstance(tags, set)):
    tags = nest.flatten(tags)

  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))
  meta_graphs = saved_model_proto.meta_graphs
  has_object_graph = (
      len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"))

  if not has_object_graph:
    # Fallback path: node filters are rejected, so the result is always
    # keyed by "root".
    if filters:
      raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                       " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info
    return {"root": root}

  metrics.IncrementReadApi(_LOAD_V2_LABEL)
  meta_graph_def = meta_graphs[0]
  # The tensor_content field holds raw little-endian bytes, which must be
  # byte-swapped before use on a big-endian host.
  if sys.byteorder == "big":
    saved_model_utils.swap_function_tensor_content(meta_graph_def,
                                                   "little", "big")

  meta_info = meta_graph_def.meta_info_def
  if tags is not None and set(tags) != set(meta_info.tags):
    raise ValueError(
        f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
        f"{export_dir} has one MetaGraph with tags "
        f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
        "pass 'None', or pass matching tags.")

  object_graph_proto = meta_graph_def.object_graph_def
  io_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=options.experimental_io_device)
  with ops.init_scope():
    try:
      loader = Loader(object_graph_proto, saved_model_proto, export_dir,
                      io_options, options, filters)
    except errors.NotFoundError as err:
      raise FileNotFoundError(
          str(err) + "\n You may be trying to load on a different device "
          "from the computational device. Consider setting the "
          "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
          "to the io_device such as '/job:localhost'.")
    root = loader.get(0)
    root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
  root.tensorflow_version = meta_info.tensorflow_version
  root.tensorflow_git_version = meta_info.tensorflow_git_version
  metrics.IncrementRead(write_version="2")

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  return {"root": root}