def load(export_dir, tags=None):
    """Load a SavedModel from `export_dir`.

    The signatures stored in the SavedModel are exposed as functions:

    ```python
    imported = tf.saved_model.load(path)
    f = imported.signatures["serving_default"]
    print(f(x=tf.constant([[1.]])))
    ```

    Objects exported with `tf.saved_model.save` additionally get their
    trackable objects and functions back as attributes:

    ```python
    exported = tf.train.Checkpoint(v=tf.Variable(3.))
    exported.f = tf.function(
        lambda x: exported.v * x,
        input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    tf.saved_model.save(exported, path)
    imported = tf.saved_model.load(path)
    assert 3. == imported.v.numpy()
    assert 6. == imported.f(x=tf.constant(2.)).numpy()
    ```

    Args:
      export_dir: The SavedModel directory to load from.
      tags: A tag or sequence of tags identifying the MetaGraph to load.
        Optional if the SavedModel contains a single MetaGraph, as for those
        exported from `tf.saved_model.save`.

    Returns:
      A trackable object with a `signatures` attribute mapping from signature
      keys to functions. If the SavedModel was exported by
      `tf.saved_model.save`, it also points to trackable objects and functions
      which were attached to the exported object.

    Raises:
      ValueError: If `tags` don't match a MetaGraph in the SavedModel.
    """
    # Accept tags=SERVING as well as tags=[SERVING]. A set is handed through
    # unchanged because nest.flatten does not consider sets to be sequences.
    needs_flatten = tags is not None and not isinstance(tags, set)
    if needs_flatten:
        tags = nest.flatten(tags)

    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    meta_graphs = saved_model_proto.meta_graphs
    has_object_graph = (
        len(meta_graphs) == 1
        and meta_graphs[0].HasField("object_graph_def"))

    if not has_object_graph:
        # No single object-graph MetaGraph: defer to the load_v1_in_v2 loader.
        with ops.init_scope():
            return load_v1_in_v2.load(export_dir, tags)

    meta_graph_def = meta_graphs[0]
    saved_tags = meta_graph_def.meta_info_def.tags
    if tags is not None and set(tags) != set(saved_tags):
        raise ValueError(
            ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
             "incompatible argument tags={} to tf.saved_model.load. You may omit "
             "it, pass 'None', or pass matching tags.")
            .format(export_dir, saved_tags, tags))

    with ops.init_scope():
        loader = _Loader(
            meta_graph_def.object_graph_def, saved_model_proto, export_dir)
        # Node 0 is what the function hands back to the caller.
        return loader.get(0)
def load(export_dir, tags=None):
    """Load a SavedModel from `export_dir`.

    The signatures stored in the SavedModel are exposed as functions:

    ```python
    imported = tf.saved_model.load(path)
    f = imported.signatures["serving_default"]
    print(f(x=tf.constant([[1.]])))
    ```

    Objects exported with `tf.saved_model.save` additionally get their
    trackable objects and functions back as attributes:

    ```python
    exported = tf.train.Checkpoint(v=tf.Variable(3.))
    exported.f = tf.function(
        lambda x: exported.v * x,
        input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    tf.saved_model.save(exported, path)
    imported = tf.saved_model.load(path)
    assert 3. == imported.v.numpy()
    assert 6. == imported.f(x=tf.constant(2.)).numpy()
    ```

    _Importing SavedModels from TensorFlow 1.x_

    SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a
    flat graph instead of `tf.function` objects. These SavedModels will have
    functions corresponding to their signatures in the `.signatures`
    attribute, but also have a `.prune` method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the
    SavedModel and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details. These SavedModels also have
    a `.variables` attribute containing imported variables, and a `.graph`
    attribute representing the whole imported graph. For SavedModels exported
    from `tf.saved_model.save`, variables are instead assigned to whichever
    attributes they were assigned before export.

    Args:
      export_dir: The SavedModel directory to load from.
      tags: A tag or sequence of tags identifying the MetaGraph to load.
        Optional if the SavedModel contains a single MetaGraph, as for those
        exported from `tf.saved_model.save`.

    Returns:
      A trackable object with a `signatures` attribute mapping from signature
      keys to functions. If the SavedModel was exported by
      `tf.saved_model.save`, it also points to trackable objects and functions
      which were attached to the exported object. In that case the object
      also carries `tensorflow_version` and `tensorflow_git_version`
      attributes copied from the MetaGraph's `meta_info_def`.

    Raises:
      ValueError: If `tags` don't match a MetaGraph in the SavedModel.
    """
    # Accept tags=SERVING as well as tags=[SERVING]. A set is handed through
    # unchanged because nest.flatten does not consider sets to be sequences.
    if not (tags is None or isinstance(tags, set)):
        tags = nest.flatten(tags)

    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    meta_graphs = saved_model_proto.meta_graphs

    if len(meta_graphs) != 1 or not meta_graphs[0].HasField("object_graph_def"):
        # No single object-graph MetaGraph: defer to the load_v1_in_v2 loader.
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
        return root

    meta_graph_def = meta_graphs[0]
    meta_info = meta_graph_def.meta_info_def
    if tags is not None and set(tags) != set(meta_info.tags):
        message = (
            "The SavedModel at {} has one MetaGraph with tags {}, but got an "
            "incompatible argument tags={} to tf.saved_model.load. You may omit "
            "it, pass 'None', or pass matching tags.")
        raise ValueError(message.format(export_dir, meta_info.tags, tags))

    with ops.init_scope():
        loader = _Loader(
            meta_graph_def.object_graph_def, saved_model_proto, export_dir)
        root = loader.get(0)

    # Record which TensorFlow build wrote the SavedModel on the result.
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    return root
def load_partial(export_dir, filters, tags=None, options=None):
    """Partially load a SavedModel (saved from V2).

    Similar to `tf.saved_model.load`, but with an additional argument that
    lets you specify which nodes to load.
    `tf.saved_model.load_partial(export_dir, ["root"])` and
    `tf.saved_model.load(export_dir)` are equivalent, except that this
    function returns its result wrapped in a dictionary keyed by node path.

    Note: This only works for SavedModels saved with TensorFlow V2 from
    `tf.saved_model.save` or Keras. This will not load SavedModels saved from
    the Estimator API.

    In TensorFlow V2, SavedModel stores the **object graph** of the saved
    object. The graph contains nodes (`tf.Module`, `tf.Variable`,
    `tf.function`, Keras layers, etc.) and edges that are the name of the
    attributes connecting the objects.

    *Example 1*

    ```python
    model = tf.Module()
    model.child_layer = tf.Module()
    model.child_layer.v = tf.Variable(5.)
    tf.saved_model.save(model, '/tmp/model')
    loaded = tf.__internal__.saved_model.load_partial(
        '/tmp/model',
        ['root.child_layer', 'root.child_layer.v'])
    loaded['root.child_layer'].v.numpy()  # 5.
    loaded['root.child_layer'].v is loaded['root.child_layer.v']  # True
    ```

    *Example 2*

    ```python
    model = tf.Module()
    model.child_layer = tf.Module()
    model.child_layer.v = tf.Variable(5.)
    tf.saved_model.save(model, '/tmp/model')
    # Create a variable
    new_variable = tf.Variable(0.)
    loaded = tf.__internal__.saved_model.load_partial(
        '/tmp/model',
        {'root.child_layer': None, 'root.child_layer.v': new_variable})
    loaded['root.child_layer'].v.numpy()  # 5.
    new_variable.numpy()  # 5.
    ```

    **Loading under different distribution strategies**

    You can load different parts of the model under different distribution
    strategies. Note that this is very experimental so use with care.

    ```python
    model = tf.Module()
    model.layer_1 = tf.Module()
    model.layer_1.v = tf.Variable(5.)
    model.layer_2 = tf.Module()
    model.layer_2.v = tf.Variable(7.)
    tf.saved_model.save(model, '/tmp/model')
    # Load with no strategy
    loaded = tf.__internal__.saved_model.load_partial(
        '/tmp/model',
        ['root.layer_1'])
    loaded['root.layer_1'].v
    # <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        loaded2 = tf.__internal__.saved_model.load_partial(
            '/tmp/model',
            ['root.layer_2'])
    loaded2['root.layer_2'].v
    # MirroredVariable:{
    #   0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
    # }
    ```

    Args:
      export_dir: The SavedModel directory to load from.
      filters: A list or dictionary where each element or key is a string
        path to nodes that should be loaded. Node paths consist of all the
        child attribute names to reach that node in the form:
        `root.{attribute_name}`. The loader will load all of the specified
        nodes and their recursive descendants. When this option is defined,
        the loader will return a dictionary mapping the node paths to the
        loaded objects.
      tags: A tag or sequence of tags identifying the MetaGraph to load.
        Optional if the SavedModel contains a single MetaGraph, as for those
        exported from `tf.saved_model.save`.
      options: `tf.saved_model.LoadOptions` object that specifies options for
        loading. A default `LoadOptions()` is used when omitted.

    Returns:
      A dictionary mapping node paths from the filter to loaded objects. When
      `filters` is empty or `None`, the dictionary has the single key
      `"root"` mapped to the fully loaded object.

    Raises:
      ValueError: If `tags` don't match the tags of the single MetaGraph, or
        if `filters` is given for a SavedModel that is not a single
        object-graph MetaGraph.
      FileNotFoundError: If constructing the loader raises
        `errors.NotFoundError`; the original error is attached as the cause.
    """
    options = options or load_options.LoadOptions()
    if tags is not None and not isinstance(tags, set):
        # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
        # sequences for nest.flatten, so we put those through as-is.
        tags = nest.flatten(tags)
    saved_model_proto, debug_info = (
        loader_impl.parse_saved_model_with_debug_info(export_dir))

    if (len(saved_model_proto.meta_graphs) == 1 and
            saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
        metrics.IncrementReadApi(_LOAD_V2_LABEL)
        meta_graph_def = saved_model_proto.meta_graphs[0]
        # The tensor_content field contains raw bytes in little-endian format,
        # which causes problems when loaded on big-endian systems, so the
        # bytes are swapped there.
        if sys.byteorder == "big":
            saved_model_utils.swap_function_tensor_content(
                meta_graph_def, "little", "big")
        if (tags is not None
                and set(tags) != set(meta_graph_def.meta_info_def.tags)):
            raise ValueError(
                f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
                f"{export_dir} has one MetaGraph with tags "
                f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
                "pass 'None', or pass matching tags.")
        object_graph_proto = meta_graph_def.object_graph_def

        ckpt_options = checkpoint_options.CheckpointOptions(
            experimental_io_device=options.experimental_io_device)
        with ops.init_scope():
            try:
                loader = Loader(object_graph_proto, saved_model_proto,
                                export_dir, ckpt_options, options, filters)
            except errors.NotFoundError as err:
                # Chain explicitly so the traceback shows the NotFoundError as
                # the direct cause rather than an incidental context.
                raise FileNotFoundError(
                    str(err) + "\n You may be trying to load on a different device "
                    "from the computational device. Consider setting the "
                    "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
                    "to the io_device such as '/job:localhost'.") from err
            root = loader.get(0)
            root.graph_debug_info = loader.adjust_debug_info_func_names(
                debug_info)
        root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        metrics.IncrementRead(write_version="2")
    else:
        # `loader` is never created on this path, so filters must be rejected
        # before the final lookup below.
        if filters:
            raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                             " version) cannot be loaded with node filters.")
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
            root.graph_debug_info = debug_info

    if filters:
        return {node_id: loader.get(node_id) for node_id in filters}
    else:
        return {"root": root}
def load(export_dir, tags=None):
    """Load a SavedModel from `export_dir`.

    The signatures stored in the SavedModel are exposed as functions:

    ```python
    imported = tf.saved_model.load(path)
    f = imported.signatures["serving_default"]
    print(f(x=tf.constant([[1.]])))
    ```

    Objects exported with `tf.saved_model.save` additionally get their
    trackable objects and functions back as attributes:

    ```python
    exported = tf.train.Checkpoint(v=tf.Variable(3.))
    exported.f = tf.function(
        lambda x: exported.v * x,
        input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    tf.saved_model.save(exported, path)
    imported = tf.saved_model.load(path)
    assert 3. == imported.v.numpy()
    assert 6. == imported.f(x=tf.constant(2.)).numpy()
    ```

    Args:
      export_dir: The SavedModel directory to load from.
      tags: A tag or sequence of tags identifying the MetaGraph to load.
        Optional if the SavedModel contains a single MetaGraph, as for those
        exported from `tf.saved_model.save`.

    Returns:
      A trackable object with a `signatures` attribute mapping from signature
      keys to functions. If the SavedModel was exported by
      `tf.saved_model.save`, it also points to trackable objects and functions
      which were attached to the exported object.

    Raises:
      ValueError: If `tags` don't match a MetaGraph in the SavedModel.
    """
    if tags is not None:
        # A single tag or a list of tags becomes a flat list. Sets are left
        # alone since nest.flatten does not consider them sequences.
        if not isinstance(tags, set):
            tags = nest.flatten(tags)

    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    candidates = saved_model_proto.meta_graphs

    if len(candidates) != 1 or not candidates[0].HasField("object_graph_def"):
        # No single object-graph MetaGraph: defer to the load_v1_in_v2 loader.
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
        return root

    meta_graph_def = candidates[0]
    if tags is not None:
        expected_tags = meta_graph_def.meta_info_def.tags
        if set(tags) != set(expected_tags):
            template = (
                "The SavedModel at {} has one MetaGraph with tags {}, but got an "
                "incompatible argument tags={} to tf.saved_model.load. You may omit "
                "it, pass 'None', or pass matching tags.")
            raise ValueError(template.format(export_dir, expected_tags, tags))

    object_graph_proto = meta_graph_def.object_graph_def
    with ops.init_scope():
        root = _Loader(object_graph_proto, saved_model_proto, export_dir).get(0)
    return root
def load_tf_checkpoints(model, config, tf_checkpoint_path):
    """Copy the variables of a TensorFlow SavedModel into a PyTorch model.

    The SavedModel at `tf_checkpoint_path` is loaded with the "train" tag.
    Each variable name is stripped of "module/" and ":0", split on "/", and
    used as an attribute path into `model`. The variable's value is then
    written to the `.data` of the object reached by that path.

    Args:
        model: Object whose attributes are walked using the variable names
            and whose matching parameters are overwritten in place.
        config: Model configuration; only printed for logging.
        tf_checkpoint_path: Path passed to `load_v1_in_v2.load`.

    Returns:
        The same `model` object, after its parameters have been overwritten.

    Raises:
        AssertionError: If the shape of a resolved parameter differs from the
            shape of the TensorFlow array. The exception args are
            `(message, pointer.shape, array.shape)`.
    """
    print("Building PyTorch model from configuration: {}".format(str(config)))
    print("Converting TensorFlow checkpoint from {}".format(tf_checkpoint_path))

    # Load weights from TF model
    init_vars = load_v1_in_v2.load(tf_checkpoint_path, tags=["train"])
    names = []
    arrays = []
    for tf_var in init_vars.variables:
        name = tf_var.name
        print("Loading TF weight {} with shape {}".format(name, tf_var.shape))
        names.append(name)
        arrays.append(tf_var.numpy())

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to
    # calculate m and v, which are not required for using a pretrained model.
    skipped_scopes = {
        "adam_v",
        "adam_m",
        "AdamWeightDecayOptimizer",
        "AdamWeightDecayOptimizer_1",
        "global_step",
    }

    for name, array in zip(names, arrays):
        name = re.sub(r"module\/|\:0", "", name).strip()
        name = name.split("/")
        if any(n in skipped_scopes for n in name):
            print("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            # "layer_3" style components become ["layer", "3", ""]: an
            # attribute name followed by an index into that attribute.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            scope = scope_names[0]
            if scope in ("kernel", "gamma", "output_weights"):
                pointer = getattr(pointer, "weight")
            elif scope in ("output_bias", "beta"):
                pointer = getattr(pointer, "bias")
            elif scope == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope)
                except AttributeError:
                    # NOTE(review): this `continue` only skips the current
                    # path component, not the whole variable; the remaining
                    # components are still resolved against the unchanged
                    # pointer. Kept as-is because existing checkpoints may
                    # depend on it — confirm before tightening.
                    print("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        # The last path component decides the final adjustment.
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)

        # Raised explicitly rather than via `assert` so the shape check still
        # runs when Python is started with -O.
        if pointer.shape != array.shape:
            raise AssertionError(
                f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched",
                pointer.shape,
                array.shape,
            )

        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
    """Loader implementation.

    Parses the SavedModel at `export_dir`. If it holds exactly one MetaGraph
    with an `object_graph_def`, the objects are built with `loader_cls`;
    otherwise loading is delegated to `load_v1_in_v2.load`.

    Args:
      export_dir: The SavedModel directory to load from.
      tags: A tag, sequence of tags, or set of tags identifying the MetaGraph
        to load. `None` skips the tag check for a single-MetaGraph
        SavedModel.
      options: Load options object; a default `load_options.LoadOptions()` is
        used when omitted. Its `experimental_io_device` is forwarded to the
        checkpoint options.
      loader_cls: Class used to build the loader. It is called with
        `(object_graph_proto, saved_model_proto, export_dir, ckpt_options,
        options, filters)`.
      filters: Optional collection of node identifiers. It is forwarded to
        the loader, and each entry is looked up with `loader.get` to build
        the result.

    Returns:
      A dictionary. With `filters`, it maps every filter entry to
      `loader.get(entry)`; otherwise it is `{"root": root}`.

    Raises:
      ValueError: If `tags` don't match the tags of the single MetaGraph, or
        if `filters` is given for a SavedModel that is not a single
        object-graph MetaGraph.
      FileNotFoundError: If constructing the loader raises
        `errors.NotFoundError`; the original error is attached as the cause.
    """
    options = options or load_options.LoadOptions()
    if tags is not None and not isinstance(tags, set):
        # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
        # sequences for nest.flatten, so we put those through as-is.
        tags = nest.flatten(tags)
    saved_model_proto, debug_info = (
        loader_impl.parse_saved_model_with_debug_info(export_dir))

    if (len(saved_model_proto.meta_graphs) == 1 and
            saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
        metrics.IncrementReadApi(_LOAD_V2_LABEL)
        meta_graph_def = saved_model_proto.meta_graphs[0]
        # The tensor_content field contains raw bytes in little-endian format,
        # which causes problems when loaded on big-endian systems, so the
        # bytes are swapped there.
        if sys.byteorder == "big":
            saved_model_utils.swap_function_tensor_content(
                meta_graph_def, "little", "big")
        if (tags is not None
                and set(tags) != set(meta_graph_def.meta_info_def.tags)):
            # Every fragment containing a placeholder needs the f prefix;
            # without it the literal text "{tags}" ends up in the message.
            raise ValueError(
                f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
                f"{export_dir} has one MetaGraph with tags "
                f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
                "pass 'None', or pass matching tags.")
        object_graph_proto = meta_graph_def.object_graph_def

        ckpt_options = checkpoint_options.CheckpointOptions(
            experimental_io_device=options.experimental_io_device)
        with ops.init_scope():
            try:
                loader = loader_cls(object_graph_proto, saved_model_proto,
                                    export_dir, ckpt_options, options, filters)
            except errors.NotFoundError as err:
                # Chain explicitly so the traceback shows the NotFoundError as
                # the direct cause rather than an incidental context.
                raise FileNotFoundError(
                    str(err) + "\n You may be trying to load on a different device "
                    "from the computational device. Consider setting the "
                    "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
                    "to the io_device such as '/job:localhost'.") from err
            root = loader.get(0)
            # Debug-info adjustment is only performed for Loader instances
            # (including subclasses); other loader classes skip it.
            if isinstance(loader, Loader):
                root.graph_debug_info = loader.adjust_debug_info_func_names(
                    debug_info)
        root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        metrics.IncrementRead(write_version="2")
    else:
        # `loader` is never created on this path, so filters must be rejected
        # before the final lookup below.
        if filters:
            raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                             " version) cannot be loaded with node filters.")
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
            root.graph_debug_info = debug_info

    if filters:
        return {node_id: loader.get(node_id) for node_id in filters}
    else:
        return {"root": root}
def load(export_dir, tags=None):
    """Load a SavedModel from `export_dir`.

    The signatures stored in the SavedModel are exposed as functions:

    ```python
    imported = tf.saved_model.load(path)
    f = imported.signatures["serving_default"]
    print(f(x=tf.constant([[1.]])))
    ```

    Objects exported with `tf.saved_model.save` additionally get their
    trackable objects and functions back as attributes:

    ```python
    exported = tf.train.Checkpoint(v=tf.Variable(3.))
    exported.f = tf.function(
        lambda x: exported.v * x,
        input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    tf.saved_model.save(exported, path)
    imported = tf.saved_model.load(path)
    assert 3. == imported.v.numpy()
    assert 6. == imported.f(x=tf.constant(2.)).numpy()
    ```

    _Importing SavedModels from TensorFlow 1.x_

    SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a
    flat graph instead of `tf.function` objects. These SavedModels will have
    functions corresponding to their signatures in the `.signatures`
    attribute, but also have a `.prune` method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the
    SavedModel and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details. These SavedModels also have
    a `.variables` attribute containing imported variables, and a `.graph`
    attribute representing the whole imported graph. For SavedModels exported
    from `tf.saved_model.save`, variables are instead assigned to whichever
    attributes they were assigned before export.

    Args:
      export_dir: The SavedModel directory to load from.
      tags: A tag or sequence of tags identifying the MetaGraph to load.
        Optional if the SavedModel contains a single MetaGraph, as for those
        exported from `tf.saved_model.save`.

    Returns:
      A trackable object with a `signatures` attribute mapping from signature
      keys to functions. If the SavedModel was exported by
      `tf.saved_model.save`, it also points to trackable objects and functions
      which were attached to the exported object. In that case the object
      also carries `tensorflow_version` and `tensorflow_git_version`
      attributes copied from the MetaGraph's `meta_info_def`.

    Raises:
      ValueError: If `tags` don't match a MetaGraph in the SavedModel.
    """
    if tags is not None:
        # A single tag or a list of tags becomes a flat list. Sets are left
        # alone since nest.flatten does not consider them sequences.
        if not isinstance(tags, set):
            tags = nest.flatten(tags)

    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    graphs = saved_model_proto.meta_graphs
    single_object_graph = (
        len(graphs) == 1 and graphs[0].HasField("object_graph_def"))

    if single_object_graph:
        meta_graph_def = graphs[0]
        info = meta_graph_def.meta_info_def
        tags_mismatch = tags is not None and set(tags) != set(info.tags)
        if tags_mismatch:
            raise ValueError((
                "The SavedModel at {} has one MetaGraph with tags {}, but got an "
                "incompatible argument tags={} to tf.saved_model.load. You may omit "
                "it, pass 'None', or pass matching tags.").format(
                    export_dir, info.tags, tags))
        with ops.init_scope():
            loader = _Loader(
                meta_graph_def.object_graph_def, saved_model_proto, export_dir)
            root = loader.get(0)
        # Record which TensorFlow build wrote the SavedModel on the result.
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
    else:
        # No single object-graph MetaGraph: defer to the load_v1_in_v2 loader.
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
    return root