def _restore_checkpoint(self):
  """Load state from checkpoint into the deserialized objects.

  Reads the checkpoint stored under the SavedModel's "variables/" directory
  and restores it into the object graph rooted at `self.get(0)`. In graph
  mode, the restore ops are additionally wired into the restored variables'
  initializers so common initialization paths (e.g. the ones used by
  ManagedSession) pick them up without further user action.

  Raises:
    NotImplementedError: If the checkpoint holds state for an object that is
      not a resource variable and therefore cannot be re-wired here.
  """
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # TODO(andresp): Clean use of private methods of TrackableSaver.
  # pylint: disable=protected-access
  saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
  # The file-prefix constant is pinned to CPU; NOTE(review): presumably to
  # keep string ops off accelerator devices — confirm.
  with ops.device("CPU"):
    saver._file_prefix_placeholder = constant_op.constant(variables_path)
  load_status = saver.restore(variables_path)
  load_status.assert_existing_objects_matched()
  checkpoint = load_status._checkpoint
  # When running in eager mode, the `restore` call above has already run and
  # restored the state of trackables, call `position.restore_ops()` will
  # return an empty list as there is nothing left to do. In graph mode, that
  # will return the list of ops that must run to restore the object on that
  # position. We have to wire them in the initializers of the objects so that
  # they get initialized properly when using common practices (e.g. the ones
  # used by ManagedSession) without further user action.
  for object_id, obj in dict(checkpoint.object_by_proto_id).items():
    position = base.CheckpointPosition(checkpoint=checkpoint,
                                       proto_id=object_id)
    restore_ops = position.restore_ops()
    if restore_ops:
      if resource_variable_ops.is_resource_variable(obj):
        # BUG FIX: `_initializer_op` must be a single op, not a Python list;
        # the original assigned the raw `restore_ops` list, which breaks
        # consumers that run the variable's initializer. Assign the op
        # directly when there is exactly one, otherwise group them.
        if len(restore_ops) == 1:
          obj._initializer_op = restore_ops[0]
        else:
          # Local import so this fix does not depend on the module's
          # (not visible here) top-of-file import block.
          from tensorflow.python.ops import control_flow_ops
          obj._initializer_op = control_flow_ops.group(*restore_ops)
      else:
        raise NotImplementedError(
            ("Missing functionality to restore state of object "
             "%r from the checkpoint." % obj))
def add_meta_graph_and_variables(self,  # pylint: disable=too-many-arguments, arguments-differ
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None,
                                 legacy_init_op=None,
                                 clear_devices=False,
                                 main_op=None,
                                 strip_default_attrs=False,
                                 saver=None,
                                 train_op=None):
  """Save graph variables and metagraph.

  Args:
    sess: The session holding the variables to save.
    tags: Tags to annotate the written meta graph with.
    signature_def_map: Optional map of signature keys to SignatureDefs.
    assets_collection: Optional collection of assets to include.
    legacy_init_op: Deprecated; folded into `main_op`.
    clear_devices: Whether to strip device annotations from the meta graph.
    main_op: Op to run at load time.
    strip_default_attrs: Whether to strip default-valued attributes.
    saver: REQUIRED. The autodist saver used to write the variables.
    train_op: Optional training op to record in the collections.

  Raises:
    AssertionError: If `saver` is None, or if variables have already been
      saved through this builder.
  """
  # users must provide an autodist saver to use saved_model
  # We must use the autodist.saver to save variables, but the saver has to be
  # created before the autodist_distributed_session. In this case, we can't
  # create an autodist saver automatically for users.
  #
  # BUG FIX: `assert` statements are stripped when Python runs with -O, which
  # would let a missing saver slip through silently; raise explicitly instead
  # (AssertionError kept for backward compatibility with existing callers).
  if saver is None:
    raise AssertionError("An autodist saver must be provided!")
  if self._has_saved_variables:  # pylint: disable=access-member-before-definition
    raise AssertionError("Graph state including variables and assets has "
                         "already been saved. Please invoke "
                         "`add_meta_graph()` instead.")
  signature_def_map = signature_def_map or {}
  self._validate_signature_def_map(signature_def_map)
  # legacy_init_op is deprecated; treated identically to main_op.
  main_op = main_op or legacy_init_op
  self._add_collections(assets_collection, main_op, train_op)
  saved_model_utils.get_or_create_variables_dir(self._export_dir)
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # Skip the meta graph and checkpoint-state proto; neither is used when the
  # SavedModel is loaded.
  saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
  meta_graph_def = saver.export_meta_graph(
      clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
  self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
  self._has_saved_variables = True  # pylint: disable=attribute-defined-outside-init
def test_export_keras_estimator(self, checkpoint_format):
  """Exports a Keras-backed estimator before and after training.

  Checks that (1) values exported immediately match the compiled model, and
  (2) values exported after training have moved away from them.
  """
  keras_model, (x_train, y_train), (
      _, _), train_input_fn, _ = get_resource_for_simple_model(
          model_type='sequential', is_evaluate=False)

  keras_model.compile(
      loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
  keras_model.fit(x_train, y_train, epochs=1)
  bias_value = keras.backend.get_value(keras_model.layers[0].bias)

  est_keras = keras_lib.model_to_estimator(
      keras_model=keras_model,
      model_dir=tempfile.mkdtemp(dir=self._base_dir),
      checkpoint_format=checkpoint_format)

  def serving_input_receiver_fn():
    feature_spec = {
        'dense_input': parsing_ops.FixedLenFeature([1], dtype=dtypes.float32)
    }
    return export_lib.build_parsing_serving_input_receiver_fn(feature_spec)

  # Try immediately exporting, testing that (1) exported values are the same,
  # and (2) estimator can be exported without saving a checkpoint into the
  # model directory.
  saved_model_dir = est_keras.export_saved_model(
      tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())
  variables_path = saved_model_utils.get_variables_path(saved_model_dir)

  variable_name = 'dense/bias'
  if checkpoint_format == 'checkpoint':
    # Object-based checkpoints store variables under object-graph keys;
    # translate the legacy name to the object-graph key.
    names_to_keys = saver_lib.object_graph_key_mapping(variables_path)
    variable_name = names_to_keys[variable_name]

  self.assertAllClose(bias_value,
                      training.load_variable(variables_path, variable_name))

  # Export the estimator after training a bit.
  # BUG FIX: `steps` must be an integer; `/` yields a float under Python 3
  # true division, so use floor division.
  est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE // 16)
  saved_model_dir = est_keras.export_saved_model(
      tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())
  variables_path = saved_model_utils.get_variables_path(saved_model_dir)
  self.assertNotAllClose(bias_value,
                         training.load_variable(variables_path, variable_name))
def _get_saved_model_ckpt(saved_model_dir):
  """Return path to variables checkpoint in a `SavedModel` directory."""
  # A valid SavedModel must carry a variables index file; locate it first.
  index_file = os.path.join(
      saved_model_utils.get_variables_dir(saved_model_dir),
      compat.as_text('variables.index'))
  if gfile.Exists(index_file):
    return saved_model_utils.get_variables_path(saved_model_dir)
  raise ValueError('Directory provided has an invalid SavedModel format: %s'
                   % saved_model_dir)
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None,
                                 legacy_init_op=None,
                                 clear_devices=False,
                                 main_op=None,
                                 strip_default_attrs=False,
                                 saver=None):
  """Write variables, assets and a tagged MetaGraphDef to the SavedModel.

  May only be called once per builder; subsequent meta graphs must be added
  through `add_meta_graph()`.
  """
  # Guard: graph state may only be written once per SavedModel.
  if self._has_saved_variables:
    raise AssertionError("Graph state including variables and assets has "
                         "already been saved. Please invoke "
                         "`add_meta_graph()` instead.")

  # Make sure every TensorInfo in the signature map is fully populated.
  signature_def_map = signature_def_map or {}
  self._validate_signature_def_map(signature_def_map)

  # `legacy_init_op` is deprecated (going away in TF 2.0) and treated exactly
  # like `main_op`, so fold it in here.
  main_op = main_op or legacy_init_op

  # Record assets and ops in the graph collections.
  self._add_collections(assets_collection, main_op, None)

  saved_model_utils.get_or_create_variables_dir(self._export_dir)
  checkpoint_path = saved_model_utils.get_variables_path(self._export_dir)

  saver = self._maybe_create_saver(saver)

  # Write only the variables. The checkpoint-state proto is skipped: it is
  # never read at SavedModel load time, and a stale state file would break
  # SavedModels that get copied or moved.
  saver.save(sess, checkpoint_path, write_meta_graph=False, write_state=False)

  # The graph typically already contains one or more Savers (e.g. one loading
  # a pretrained embedding, another for the model weights). Dropping them via
  # clear_extraneous_savers=True was the original plan, but that option is
  # known to break some graphs, so it stays off until resolved.
  # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
  meta_graph = saver.export_meta_graph(
      clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

  # Tag the meta graph and attach it to the SavedModel.
  self._tag_and_add_meta_graph(meta_graph, tags, signature_def_map)

  # Block any further attempt to save variables through this builder.
  self._has_saved_variables = True
def __init__(self, export_dir):
  """Creates a `SavedModelLoader`.

  Args:
    export_dir: Directory in which the SavedModel protocol buffer and
      variables to be loaded are located.
  """
  # Remember the SavedModel location and precompute the checkpoint prefix.
  self._export_dir = export_dir
  self._variables_path = saved_model_utils.get_variables_path(export_dir)
  # Parse the SavedModel proto eagerly at construction time.
  self._saved_model = parse_saved_model(export_dir)
def __init__(self, export_dir):
  """Creates a `SavedModelLoader`.

  Args:
    export_dir: Directory in which the SavedModel protocol buffer and
      variables to be loaded are located.
  """
  # Keep the export directory and derive the variables checkpoint prefix.
  self._export_dir = export_dir
  self._variables_path = saved_model_utils.get_variables_path(export_dir)
  # Parse the SavedModel proto eagerly at construction time.
  self._saved_model = _parse_saved_model(export_dir)
def _restore_checkpoint(self):
  """Load state from checkpoint into the deserialized objects."""
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # TODO(b/205010730): Clean use of private methods of TrackableSaver.
  # pylint: disable=protected-access
  saver = checkpoint.TrackableSaver(
      graph_view.ObjectGraphView(self.get(0)))
  # The file-prefix constant is pinned to CPU. NOTE(review): presumably to
  # keep string/IO ops off accelerator devices — confirm.
  with ops.device("CPU"):
    saver._file_prefix_placeholder = constant_op.constant(
        variables_path)
  if self._save_options.allow_partial_checkpoint:
    # Partial restores are allowed, but a restore that matched nothing at
    # all is still rejected.
    load_status = saver.restore(
        variables_path, self._checkpoint_options).expect_partial()
    load_status.assert_nontrivial_match()
  else:
    load_status = saver.restore(variables_path,
                                self._checkpoint_options)
    load_status.assert_existing_objects_matched()
  ckpt = load_status._checkpoint

  if not context.executing_eagerly():
    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, and calling `position.restore_ops()`
    # would re-run the restore. In graph mode, that will return a cached list
    # of ops that must run to restore the object on that position. We have to
    # wire them in the initializers of the objects so that they get
    # initialized properly when using common practices (e.g. the ones used by
    # ManagedSession) without further user action.
    for object_id, obj in dict(ckpt.object_by_proto_id).items():
      position = base.CheckpointPosition(checkpoint=ckpt,
                                         proto_id=object_id)
      # Registered (custom) savers have no graph-mode op-wiring path here.
      registered_saver = position.get_registered_saver_name()
      if registered_saver:
        raise NotImplementedError(
            "Loading a SavedModel that uses registered checkpoint saver is "
            f"not supported in graph mode. The loaded object {obj} uses the "
            f"saver registered with the name {registered_saver}.")

      restore_ops = position.restore_ops()
      if restore_ops:
        if resource_variable_ops.is_resource_variable(obj):
          # A variable's initializer must be a single op: use the op
          # directly when there is exactly one, otherwise group them.
          if len(restore_ops) == 1:
            obj._initializer_op = restore_ops[0]
          else:
            obj._initializer_op = control_flow_ops.group(
                *restore_ops)
        elif isinstance(obj, lookup_ops.LookupInterface):
          # We don't need to check for eager execution here, since this code
          # path should only be taken if we are restoring in graph mode.
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                restore_ops)
        else:
          raise NotImplementedError(
              f"Unable to restore state of object {obj} from the checkpoint."
          )
def _export_model_json_and_variables(model, saved_model_path):
  """Save model variables and json structure into SavedModel subdirectories."""
  # Serialize the model architecture to JSON and drop it under assets/.
  config_json = model.to_json()
  json_path = os.path.join(
      saved_model_utils.get_or_create_assets_dir(saved_model_path),
      compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
  file_io.write_string_to_file(json_path, config_json)

  # Write the weights as a TF-format checkpoint under variables/.
  saved_model_utils.get_or_create_variables_dir(saved_model_path)
  checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
  return checkpoint_prefix
def _restore_checkpoint(self):
  """Load state from checkpoint into the deserialized objects."""
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # TODO(andresp): Clean use of private methods of TrackableSaver.
  # pylint: disable=protected-access
  saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
  # The file-prefix constant is pinned to CPU. NOTE(review): presumably to
  # keep string/IO ops off accelerator devices — confirm.
  with ops.device("CPU"):
    saver._file_prefix_placeholder = constant_op.constant(
        variables_path)
  if self._expect_partial_checkpoint:
    # Caller opted in to a partial match; silence unmatched-object warnings.
    load_status = saver.restore(
        variables_path, self._checkpoint_options).expect_partial()
  else:
    load_status = saver.restore(variables_path,
                                self._checkpoint_options)
    load_status.assert_existing_objects_matched()
  checkpoint = load_status._checkpoint
  # When running in eager mode, the `restore` call above has already run and
  # restored the state of trackables, call `position.restore_ops()` will
  # return an empty list as there is nothing left to do. In graph mode, that
  # will return the list of ops that must run to restore the object on that
  # position. We have to wire them in the initializers of the objects so that
  # they get initialized properly when using common practices (e.g. the ones
  # used by ManagedSession) without further user action.
  for object_id, obj in dict(checkpoint.object_by_proto_id).items():
    position = base.CheckpointPosition(checkpoint=checkpoint,
                                       proto_id=object_id)
    restore_ops = position.restore_ops()
    if restore_ops:
      if resource_variable_ops.is_resource_variable(obj):
        # A variable's initializer must be a single op: use the op directly
        # when there is exactly one, otherwise group them.
        if len(restore_ops) == 1:
          obj._initializer_op = restore_ops[0]
        else:
          obj._initializer_op = control_flow_ops.group(
              *restore_ops)
      elif isinstance(obj, lookup_ops.LookupInterface):
        # We don't need to check for eager execution here, since this code
        # path should only be taken if we are restoring in graph mode.
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                              restore_ops)
      else:
        raise NotImplementedError(
            ("Missing functionality to restore state of object "
             "%r from the checkpoint." % obj))
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its
  value having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` are
  available to programs consuming the SavedModel (e.g. serving APIs). Pass
  `@tf.function(input_signature=...)`-decorated functions directly, concrete
  functions obtained via `get_concrete_function`, or a dictionary mapping
  signature keys to either. When `signatures` is omitted, `obj` is searched
  for `@tf.function`-decorated methods; if exactly one is found it becomes
  the default signature. `Tensor` arguments are matched by name — the Python
  argument name by default, overridable via `name=...` on the corresponding
  `tf.TensorSpec` — and explicit naming is required when several tensors flow
  through a single Python argument. Signature outputs must be flat lists
  (numbered) or dictionaries with string keys (keys name the outputs).
  On objects returned by `tf.saved_model.load`, signatures appear under the
  reserved `.signatures` attribute; saving an object with a custom
  `.signatures` attribute raises an exception.

  Since `tf.keras.Model` objects are also Trackable, this function exports
  Keras models directly; models constructed from inputs and outputs already
  have a signature, and their forward pass is exported when neither a
  decorator nor a `signatures` argument is given.

  Variables must be tracked by assigning them to an attribute of a tracked
  object or of `obj` itself (layers from `tf.keras.layers` and optimizers
  from `tf.train` do this automatically). This matches the
  `tf.train.Checkpoint` tracking scheme, so an exported `Checkpoint` may be
  restored as a training checkpoint by pointing
  `tf.train.Checkpoint.restore` at the SavedModel's "variables/"
  subdirectory.

  `tf.function` does not hard-code device annotations from outside the
  function body, instead using the calling context's device; `tf.device`
  annotations inside the body ARE hard-coded (and discouraged), and active
  distribution strategies also hard-code placements. Default-valued
  attributes are stripped automatically for forward compatibility; other
  incompatibilities (e.g. ops unknown to the consumer) are not handled.

  To let downstream tools refer to all concrete functions generated by one
  `tf.function`, store an alias map via
  `tf.saved_model.SaveOptions(function_aliases={'my_func': func})`.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`; may also be a dictionary mapping
      signature keys (arbitrary strings, typically from the
      `tf.saved_model.signature_constants` module) to such values.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration. May not be
  called from within a function body.
  @end_compatibility
  """
  options = options or save_options.SaveOptions()
  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  proto = saved_model_pb2.SavedModel()
  meta_graph = proto.meta_graphs.add()
  _, exported_graph, object_saver, asset_info = _build_meta_graph(
      obj, export_dir, signatures, options, meta_graph)
  proto.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

  # Write the checkpoint, copy assets into the assets directory, and write
  # out the SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  ckpt_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=options.experimental_io_device)
  object_saver.save(
      utils_impl.get_variables_path(export_dir), options=ckpt_options)
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)

  if context.executing_eagerly():
    try:
      # Make sure asynchronous save operations have finished before the proto
      # is written: users treat saved_model.pb as the completion marker.
      context.async_wait()
    except errors.NotFoundError as err:
      raise FileNotFoundError(
          str(err) + "\n If trying to save on a different device from the "
          "computational device, consider using setting the "
          "`experimental_io_device` option on tf.saved_model.SaveOptions "
          "to the io_device such as '/job:localhost'.")

  # saved_model.pb must be the LAST file written: its presence signals that
  # the SavedModel directory is completely written.
  pb_path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  file_io.atomic_write_string_to_file(
      pb_path, proto.SerializeToString(deterministic=True))

  # Break reference cycles so repeated export()s don't pile work on the
  # garbage collector; captured constants in the saved graph must stay alive
  # until this point.
  ops.dismantle_graph(exported_graph)
def export(obj, export_dir, signatures=None):
  # pylint: disable=line-too-long
  """Exports the Checkpointable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  The `signatures` argument names the TensorFlow functions made available to
  programs consuming the SavedModel (e.g. serving APIs). Functions decorated
  with `@tf.function(input_signature=...)` may be passed directly; functions
  created without a signature can be converted with
  `f.get_concrete_function(...)` first. In either case, `Tensor` inputs not
  associated with a unique Python argument name (positional arguments through
  variadic `*args`, or multiple tensors inside one nested structure) must be
  named explicitly via their `tf.TensorSpec` objects. Signature outputs must
  be flat lists (outputs are numbered) or dictionaries with string keys (the
  keys name the outputs).

  Exporting with a signature specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=tf.TensorSpec(shape=[None], dtype=tf.string))
    def serve(serialized):
      ...

  m = Model()
  tf.saved_model.export(m, '/tmp/saved_model/', signatures=m.serve)
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def compute(x):
      ...

  m = Model()
  tf.saved_model.export(
      m, '/tmp/saved_model/',
      signatures=m.compute.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or of `obj` itself (layers from `tf.keras.layers` and optimizers
  from `tf.train` do this automatically). This is the same tracking scheme
  `tf.train.Checkpoint` uses, so an exported `Checkpoint` object may be
  restored as a training checkpoint by pointing
  `tf.train.Checkpoint.restore` at the SavedModel's "variables/"
  subdirectory.

  Args:
    obj: A checkpointable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `tf.function`-decorated function `f`; may also be a dictionary mapping
      signature keys (arbitrary strings, typically from the
      `tf.saved_model.signature_constants` module) to such values.

  Raises:
    ValueError: If `obj` is not checkpointable.
  """
  # pylint: enable=line-too-long
  if not isinstance(obj, base.CheckpointableBase):
    raise ValueError(
        "Expected a Checkpointable object for export, got {}.".format(obj))

  # Write the variables checkpoint first, under "variables/".
  object_saver = util.CheckpointableSaver(obj)
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))

  signatures = _canonicalize_signatures(signatures)
  graph_def, signatures, saver_def = _make_graph_def(obj, signatures,
                                                     object_saver)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  proto = saved_model_pb2.SavedModel()
  proto.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)
  meta_graph = proto.meta_graphs.add()
  meta_graph.saver_def.CopyFrom(saver_def)
  meta_graph.graph_def.MergeFrom(graph_def)
  for key, signature in signatures.items():
    meta_graph.signature_def[key].MergeFrom(signature)

  pb_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  file_io.write_string_to_file(pb_path, proto.SerializeToString())
def save(obj, export_dir, signatures=None):
  # pylint: disable=line-too-long
  """Exports the Checkpointable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.train.Checkpoint):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its
  value having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` are
  available to programs consuming the SavedModel (e.g. serving APIs). Pass
  `@tf.function(input_signature=...)`-decorated functions directly, concrete
  functions obtained via `get_concrete_function`, or a dictionary mapping
  signature keys to either. When `signatures` is omitted, `obj` is searched
  for `@tf.function`-decorated methods; if exactly one is found it becomes
  the default signature. `Tensor` arguments are matched by name — the Python
  argument name by default, overridable via `name=...` on the corresponding
  `tf.TensorSpec` — and explicit naming is required when several tensors flow
  through a single Python argument. Signature outputs must be flat lists
  (numbered) or dictionaries with string keys (keys name the outputs).

  Since `tf.keras.Model` objects are also Checkpointable, this function
  exports Keras models directly; models constructed from inputs and outputs
  already have a signature and export their forward pass when neither a
  decorator nor a `signatures` argument is given. Variables must be tracked
  by assigning them to an attribute of a tracked object or of `obj` itself —
  the same tracking scheme `tf.train.Checkpoint` uses, so the exported
  "variables/" subdirectory may also be restored as a training checkpoint.
  Currently variables are the only stateful objects supported, but others
  (e.g. tables) will be supported in the future.

  `tf.function` does not hard-code device annotations from outside the
  function body, instead using the calling context's device; `tf.device`
  annotations inside the body ARE hard-coded (and discouraged), and active
  distribution strategies also hard-code placements. Default-valued
  attributes are stripped automatically for forward compatibility.

  The current implementation targets serving use-cases and omits information
  the planned `tf.saved_model.load` will need; models exported today will not
  be compatible with `tf.saved_model.load` when it is implemented, and future
  versions of `save` may raise on objects exportable today (e.g. due to
  complex/non-serializable default arguments). Such backwards-incompatible
  API changes are expected only prior to the TensorFlow 2.0 release.

  Args:
    obj: A checkpointable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`; may also be a dictionary mapping
      signature keys (arbitrary strings, typically from the
      `tf.saved_model.signature_constants` module) to such values.

  Raises:
    ValueError: If `obj` is not checkpointable.

  @compatibility(eager)
  Not supported when graph building. From TensorFlow 1.x,
  `tf.enable_eager_execution()` must run first. May not be called from within
  a function body.
  @end_compatibility
  """
  if not context.executing_eagerly():
    with ops.init_scope():
      if context.executing_eagerly():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
      else:
        raise AssertionError(
            "tf.saved_model.save is not supported when graph building. "
            "tf.enable_eager_execution() must run first when calling it from "
            "TensorFlow 1.x.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.CheckpointableBase):
    raise ValueError(
        "Expected a Checkpointable object for export, got {}.".format(obj))

  if signatures is None:
    # Run before the checkpoint is written: looping over attributes may
    # create variables as a side effect.
    signatures = _find_function_to_export(obj)
  signatures = _canonicalize_signatures(signatures)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  proto = saved_model_pb2.SavedModel()
  meta_graph = proto.meta_graphs.add()
  object_saver = util.CheckpointableSaver(obj)
  asset_info = _fill_meta_graph_def(meta_graph, obj, signatures, object_saver)
  proto.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)

  # Protos are done; now write the checkpoint, copy assets into the assets
  # directory, and write out the SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  pb_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  file_io.write_string_to_file(pb_path, proto.SerializeToString())
  _write_object_graph(obj, export_dir, asset_info.asset_index)
def _export_model_variables(model, saved_model_path):
  """Writes the model's weights as a TF checkpoint under the variables folder.

  Args:
    model: A Keras model; its weights are written via `save_weights`.
    saved_model_path: Root directory of the SavedModel being exported.

  Returns:
    The checkpoint prefix path the weights were written to.
  """
  # Make sure the "variables/" subdirectory exists before writing.
  saved_model_utils.get_or_create_variables_dir(saved_model_path)
  prefix = saved_model_utils.get_variables_path(saved_model_path)
  model.save_weights(prefix, save_format='tf', overwrite=True)
  return prefix
def _restore_checkpoint(self):
  """Restores object state from the checkpoint in the export directory."""
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # Restore into the root of the deserialized object graph and require that
  # every checkpointed value found a matching object.
  restorer = util.CheckpointableSaver(self.get(0))
  status = restorer.restore(variables_path)
  status.assert_consumed()
def save(obj, export_dir, signatures=None):
  """Exports the Checkpointable object `obj` to SavedModel format.

  Writes a SavedModel directory containing a checkpoint of the variables
  tracked by `obj`, copies of any assets its functions reference, and a
  `saved_model.pb` holding the exported signatures, plus an object graph
  proto describing `obj`'s dependency structure.

  Signatures may be supplied as `tf.function`s with an input signature, as
  concrete functions from `f.get_concrete_function(...)`, or as a dictionary
  from signature keys (typically from `tf.saved_model.signature_constants`)
  to such functions. When `signatures` is omitted, `obj` is searched for
  exactly one `@tf.function`-decorated method, which becomes the default
  serving signature. `Tensor` arguments are matched by name (the Python
  argument name, or a `name=...` override on the `tf.TensorSpec`); outputs
  must be flat lists (numbered) or string-keyed dictionaries of `Tensor`s.

  Variable tracking uses the same scheme as `tf.train.Checkpoint`, so the
  SavedModel's "variables/" subdirectory may be restored as a training
  checkpoint via `tf.train.Checkpoint.restore`.

  Args:
    obj: A checkpointable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional; a `tf.function` with an input signature, a concrete
      function, or a dictionary mapping signature keys to either of those.

  Raises:
    ValueError: If `obj` is not checkpointable.

  @compatibility(eager)
  Not supported when graph building. From TensorFlow 1.x,
  `tf.enable_eager_execution()` must run first. May not be called from
  within a function body.
  @end_compatibility
  """
  if not context.executing_eagerly():
    with ops.init_scope():
      if context.executing_eagerly():
        # Graph building but eager in init scope: we are inside a traced
        # function.
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
      else:
        raise AssertionError(
            "tf.saved_model.save is not supported when graph building. "
            "tf.enable_eager_execution() must run first when calling it from "
            "TensorFlow 1.x.")
  if not isinstance(obj, base.Checkpointable):
    raise ValueError(
        "Expected a Checkpointable object for export, got {}.".format(obj))

  # Constructing a _SaveableView may create variables as a side effect (e.g.
  # when attributes are first touched), so build a throwaway view first and
  # keep only the second, now-stable one.
  _ = _SaveableView(obj)
  view = _SaveableView(obj)
  if signatures is None:
    signatures = _find_function_to_export(view)
  signatures = _canonicalize_signatures(signatures)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  proto = saved_model_pb2.SavedModel()
  meta_graph = proto.meta_graphs.add()
  # TODO(andresp): Should this be using the saveable view?
  object_saver = util.CheckpointableSaver(obj)
  asset_info = _fill_meta_graph_def(meta_graph, view, signatures, object_saver)
  proto.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

  # Everything above was pure proto construction; now do the I/O: write the
  # variables checkpoint, copy assets into place, serialize the SavedModel
  # proto, and record the object graph.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(
      asset_info.asset_filename_map, export_dir)
  proto_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  file_io.write_string_to_file(proto_path, proto.SerializeToString())
  _write_object_graph(view, export_dir, asset_info.asset_index)
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_list=None,
                                 clear_devices=False,
                                 init_op=None,
                                 train_op=None,
                                 strip_default_attrs=False,
                                 saver=None):
  # pylint: disable=line-too-long
  """Saves variables from `sess` and adds the current meta graph.

  Creates (or uses the supplied) Saver to checkpoint the session's variables
  and exports the corresponding meta graph def. Variables are assumed to be
  initialized already. For a given `SavedModelBuilder`, this API must be
  called exactly once and for the first meta graph; subsequent meta graphs
  must be added with `add_meta_graph()`.

  Args:
    sess: The TensorFlow session from which to save the meta graph and
      variables.
    tags: The set of tags with which to save the meta graph.
    signature_def_map: The map of signature defs to add to the meta graph
      def.
    assets_list: Assets to be saved with SavedModel.
    clear_devices: Set to true if the device info on the default graph
      should be cleared.
    init_op: Op or group of ops to execute when the graph is loaded. Note
      that when the init_op is specified it is run after the restore op at
      load-time.
    train_op: Op or group of ops that trains the model when run. This will
      not be run automatically when the graph is loaded, instead saved in a
      SignatureDef accessible through the exported MetaGraph.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will
      be removed from the NodeDefs. For a detailed guide, see [Stripping
      Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    saver: An instance of tf.compat.v1.train.Saver that will be used to
      export the metagraph and save variables. If None, a sharded Saver
      that restores all variables will be used.
  """
  # pylint: enable=line-too-long
  if self._has_saved_variables:
    raise AssertionError("Graph state including variables and assets has "
                         "already been saved. Please invoke "
                         "`add_meta_graph()` instead.")

  # Make sure every TensorInfo in the supplied signatures is fully populated.
  signature_def_map = signature_def_map or {}
  self._validate_signature_def_map(signature_def_map)

  # Expose the initialization and training ops through dedicated
  # SignatureDefs on the MetaGraphDef.
  _add_op_to_signature_def_map(signature_def_map, init_op,
                               constants.INIT_OP_SIGNATURE_KEY)
  _add_op_to_signature_def_map(signature_def_map, train_op,
                               constants.TRAIN_OP_SIGNATURE_KEY)

  saved_model_utils.get_or_create_variables_dir(self._export_dir)
  checkpoint_path = saved_model_utils.get_variables_path(self._export_dir)

  export_saver = self._maybe_create_saver(saver)
  # Disable writing the checkpoint state proto: it is not used during
  # SavedModel loading, and since a SavedModel can be copied or moved the
  # state file could only become outdated.
  export_saver.save(
      sess, checkpoint_path, write_meta_graph=False, write_state=False)

  # The graph almost certainly already contains one or more Savers (e.g. one
  # for a pretrained embedding and another for the model weights). Removing
  # the preexisting ones motivated the clear_extraneous_savers option, but
  # that option breaks the graph in some edge cases, so it stays off for now.
  # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
  meta_graph_def = export_saver.export_meta_graph(
      clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

  # Persist any asset files and reference them from the meta graph.
  self._save_and_write_assets(meta_graph_def, assets_list)

  # Record the tagged meta graph, then block further variable saves on this
  # builder instance.
  self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
  self._has_saved_variables = True
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None,
                                 legacy_init_op=None,
                                 clear_devices=False,
                                 main_op=None,
                                 strip_default_attrs=False,
                                 saver=None):
  # pylint: disable=line-too-long
  """Adds the current meta graph to the SavedModel and saves variables.

  Creates a Saver to save the variables from the provided session. Exports
  the corresponding meta graph def. This function assumes that the variables
  to be saved have been initialized. For a given `SavedModelBuilder`, this
  API must be called exactly once and for the first meta graph to save. For
  subsequent meta graph defs to be added, the `add_meta_graph()` API must be
  used.

  Args:
    sess: The TensorFlow session from which to save the meta graph and
      variables.
    tags: The set of tags with which to save the meta graph.
    signature_def_map: The map of signature def map to add to the meta graph
      def.
    assets_collection: Assets collection to be saved with SavedModel.
    legacy_init_op: Legacy support for op or group of ops to execute after
      the restore op upon a load. Deprecated; please use main_op instead.
    clear_devices: Set to true if the device info on the default graph
      should be cleared.
    main_op: Op or group of ops to execute when the graph is loaded. Note
      that when the main_op is specified it is run after the restore op at
      load-time.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will
      be removed from the NodeDefs. For a detailed guide, see [Stripping
      Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    saver: An instance of tf.train.Saver that will be used to export the
      metagraph and save variables. If None, a sharded Saver that restores
      all variables will be used.
  """
  # pylint: enable=line-too-long
  if self._has_saved_variables:
    raise AssertionError("Graph state including variables and assets has "
                         "already been saved. Please invoke "
                         "`add_meta_graph()` instead.")

  # Validate the signature def map to ensure all included TensorInfos are
  # properly populated. Default to an empty map first so that validation and
  # the MetaGraph tagging below never see `None` — this matches the sibling
  # init_op/train_op variant of this method.
  signature_def_map = signature_def_map or {}
  self._validate_signature_def_map(signature_def_map)

  # legacy_init_op is deprecated, and going away in TF 2.0.
  # Re-mapping to main_op, as treatment is identical regardless.
  main_op = main_op or legacy_init_op

  # Add assets and ops
  self._add_collections(assets_collection, main_op, None)

  saved_model_utils.get_or_create_variables_dir(self._export_dir)
  variables_path = saved_model_utils.get_variables_path(self._export_dir)

  saver = self._maybe_create_saver(saver)

  # Save the variables. Also, disable writing the checkpoint state proto. The
  # file is not used during SavedModel loading. In addition, since a
  # SavedModel can be copied or moved, this avoids the checkpoint state to
  # become outdated.
  saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

  # Export the meta graph def.
  # The graph almost certainly previously contained at least one Saver, and
  # possibly several (e.g. one for loading a pretrained embedding, and
  # another for the model weights). Removing the preexisting ones was the
  # motivation for the clear_extraneous_savers option, but it turns out that
  # there are edge cases where that option breaks the graph. Until that is
  # resolved, we just leave the option set to False for now.
  # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
  meta_graph_def = saver.export_meta_graph(
      clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

  # Tag the meta graph def and add it to the SavedModel.
  self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  # Mark this instance of SavedModel as having saved variables, such that
  # subsequent attempts to save variables will fail.
  self._has_saved_variables = True
def save(obj, export_dir, signatures=None, options=None):
  """Exports the Trackable object `obj` to SavedModel format.

  Writes a SavedModel directory containing a checkpoint of the variables
  tracked by `obj`, copies of any referenced assets, an embedded object
  graph proto, optional debug info, and finally the `saved_model.pb` proto
  itself (written last, so its presence signals a complete export).

  Signatures may be supplied as `tf.function`s with an input signature, as
  concrete functions from `f.get_concrete_function(...)`, or as a dictionary
  from signature keys (typically from `tf.saved_model.signature_constants`)
  to such functions; when omitted, `obj` is searched for a single
  `@tf.function`-decorated method to use as the default serving signature.
  Exported signatures are available after `tf.saved_model.load` via the
  reserved `.signatures` attribute. `Tensor` arguments are matched by name
  (the Python argument name, or a `name=...` override on the
  `tf.TensorSpec`); outputs must be flat lists (numbered) or string-keyed
  dictionaries of `Tensor`s.

  Variable tracking uses the same scheme as `tf.train.Checkpoint`, so the
  SavedModel's "variables/" subdirectory may be restored as a training
  checkpoint.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional; a `tf.function` with an input signature, a
      concrete function, or a dictionary mapping signature keys to either of
      those.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x
  will add new save operations to the default graph each iteration. May not
  be called from within a function body.
  @end_compatibility
  """
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  options = options or save_options.SaveOptions()

  graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(graph_view)

  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(graph_view)
  # Attach the signature map to the object graph under the reserved
  # `.signatures` attribute so it round-trips through load().
  signature_map = signature_serialization.create_signature_map(signatures)
  graph_view.add_object(
      parent_node=graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Building a _SaveableView may create variables as a side effect, so build
  # a throwaway view first and keep only the second, now-frozen one.
  _ = _SaveableView(graph_view)
  frozen_view = _SaveableView(graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  proto = saved_model_pb2.SavedModel()
  meta_graph_def = proto.meta_graphs.add()
  object_saver = util.TrackableSaver(graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, frozen_view, signatures, options.namespace_whitelist)
  proto.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

  # Everything above was pure proto construction; now do the I/O: write the
  # variables checkpoint, copy assets, and serialize the protos.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(
      asset_info.asset_filename_map, export_dir)
  proto_path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  meta_graph_def.object_graph_def.CopyFrom(
      _serialize_object_graph(frozen_view, asset_info.asset_index))

  # Save debug info, if requested.
  if options.save_debug_info:
    debug_info = _export_debug_info(exported_graph)
    file_io.atomic_write_string_to_file(
        os.path.join(
            utils_impl.get_or_create_debug_dir(export_dir),
            constants.DEBUG_INFO_FILENAME_PB),
        debug_info.SerializeToString(deterministic=True))

  # This must be the last file operation of the export: users rely on the
  # presence of saved_model_dir/saved_model.pb as the indication that the
  # SavedModel is completely written.
  file_io.atomic_write_string_to_file(
      proto_path, proto.SerializeToString(deterministic=True))

  # Break reference cycles so repeated export()s leave no work for the
  # garbage collector; captured constants in the saved graph must stay alive
  # until this point.
  ops.dismantle_graph(exported_graph)