def test_save_variable_devices(self, save_devices, meta_graph_only):
    """Checks when variable device placements end up in the exported protos.

    Two variables are created on CPU:0 and CPU:1 and exported either through
    `save.export_meta_graph` or `save.save`. Device strings are expected on
    both the graph nodes and the object-graph variable nodes only when
    `save_devices` is `VariablePolicy.SAVE_VARIABLE_DEVICES`.

    Args:
      save_devices: Value passed as `experimental_variable_policy`.
      meta_graph_only: If truthy, export only the MetaGraph file instead of a
        full SavedModel.
    """
    # Guarantee that two logical CPU devices are available.
    context._reset_context()
    physical_cpus = context.context().list_physical_devices("CPU")
    if len(physical_cpus) == 1:
        logical_cpus = [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration(),
        ]
        context.context().set_logical_device_configuration(
            physical_cpus[0], logical_cpus)
    context.ensure_initialized()

    root = tracking.AutoTrackable()
    with ops.device("CPU:0"):
        root.v0 = variables.Variable(1., name="v0")
    with ops.device("CPU:1"):
        root.v1 = variables.Variable(1., name="v1")

    options = save_options.SaveOptions(
        experimental_variable_policy=save_devices)
    file_name = os.path.join(self.get_temp_dir(), "saved_model")

    # Export, then read the MetaGraph back through the matching reader.
    if meta_graph_only:
        save.export_meta_graph(obj=root, filename=file_name, options=options)
        meta = meta_graph.read_meta_graph_file(file_name)
    else:
        save.save(obj=root, export_dir=file_name, options=options)
        meta = loader_impl.parse_saved_model(file_name).meta_graphs[0]

    devices_expected = (
        save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES)

    def check_devices(first, second):
        # Both entries must exist; devices are set only when requested.
        self.assertIsNotNone(first)
        self.assertIsNotNone(second)
        if devices_expected:
            self.assertIn("CPU:0", first.device)
            self.assertIn("CPU:1", second.device)
        else:
            self.assertEmpty(first.device)
            self.assertEmpty(second.device)

    # Check devices in meta graph nodes.
    graph_nodes = meta.graph_def.node

    def find_graph_node(name):
        return next((n for n in graph_nodes if n.name == name), None)

    check_devices(find_graph_node("v0"), find_graph_node("v1"))

    # Check devices in object graph nodes.
    object_nodes = meta.object_graph_def.nodes

    def find_saved_variable(name):
        return next(
            (n.variable for n in object_nodes
             if n.HasField("variable") and n.variable.name == name),
            None)

    check_devices(find_saved_variable("v0"), find_saved_variable("v1"))
def test_save_invalid_custom_gradients(self):
    """Checks that saving a model with an untraceable gradient logs a warning.

    The full custom gradients test lives in load_test.py. Here the model is
    saved with `experimental_custom_gradients=None` and only the logged
    warning text is verified.
    """

    @custom_gradient.custom_gradient
    def invalid(x):

        def grad(dy):
            # The gradient cannot be traced: it always raises.
            raise NotImplementedError

        return x, grad

    root = tracking.AutoTrackable()
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root.f = def_function.function(invalid, input_signature=[float_spec])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(experimental_custom_gradients=None)

    expected_warning = (
        "Your model contains untraceable custom gradients. This "
        "warning may become an error in the future. Please set the option "
        "tf.saved_model.SaveOption(experimental_custom_gradients=True) to get "
        "the full error message.")

    with self.assertLogs(level="WARNING") as logs:
        save.save(root, save_dir, root.f, options=options)

    logged_text = "".join(logs.output)
    self.assertIn(expected_warning, logged_text)
def test_save_debug_info_enabled(self):
    """Checks that `save_debug_info=True` writes a parsable debug info file.

    The file is expected at `<save_dir>/debug/saved_model_debug_info.pb` and
    must contain a trace whose key starts with "DEBUG_INFO_OP@".
    """
    root = tracking.AutoTrackable()
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root.f = def_function.function(
        lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
        input_signature=[float_spec])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    debug_options = save_options.SaveOptions(save_debug_info=True)
    save.save(root, save_dir, root.f, options=debug_options)

    debug_info_file_name = os.path.join(save_dir, "debug",
                                        "saved_model_debug_info.pb")
    self.assertTrue(os.path.exists(debug_info_file_name))

    debug_info = graph_debug_info_pb2.GraphDebugInfo()
    with open(debug_info_file_name, "rb") as f:
        serialized = f.read()
        debug_info.ParseFromString(serialized)

    # Verify that there is a trace for DEBUG_INFO_OP just to ensure that
    # function debug info tracing is nominally functioning.
    found_op = any(
        key.startswith("DEBUG_INFO_OP@") for key in debug_info.traces.keys())
    self.assertTrue(found_op, "Did not find DEBUG_INFO_OP in trace")
def testTrtGraphConverterV2_SaveWithOptions(self):
    """Test to make sure that save method respects options kwarg."""
    # Build and export the source model that the converter will read.
    source_dir = self.mkdtemp()
    model = self._GetModelForV2()
    signatures = {_SAVED_MODEL_SIGNATURE_KEY: model.run}
    save.save(model, source_dir, signatures)

    # Run TRT conversion.
    converter = self._CreateConverterV2(source_dir)
    converter.convert()

    # Replace the `save` name inside trt_convert so the call can be inspected.
    with mock.patch.object(trt_convert, "save") as mock_save:
        mock_save.save = mock.MagicMock()

        # Save the converted model, forwarding explicit options.
        target_dir = self.mkdtemp()
        options = save_options.SaveOptions(save_debug_info=True)
        converter.save(target_dir, options=options)

        # The patched save function must have been called exactly once with
        # the very same options object.
        mock_save.save.assert_called_once_with(
            mock.ANY, mock.ANY, mock.ANY, options=options)
def test_expand_distributed_variables(self, expand_strategy, policy):
    """Checks how a mirrored variable is exported under each variable policy.

    A variable is created under a two-replica `MirroredStrategy`, exported with
    `save.export_meta_graph`, and the resulting object graph and function
    library are inspected.

    Args:
      expand_strategy: Value passed as `experimental_variable_policy`.
      policy: Value assigned to `strategy.extended._use_var_policy`.
    """
    # 1. Create a context with both CPU:0 and CPU:1.
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
        context.context().set_logical_device_configuration(
            cpus[0], [
                context.LogicalDeviceConfiguration(),
                context.LogicalDeviceConfiguration()
            ])
    context.ensure_initialized()

    # 2. Create and save a model under a mirrored strategy.
    file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
    strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"])
    strategy.extended._use_var_policy = policy
    with strategy.scope():
        root = tracking.AutoTrackable()
        root.v = variables.Variable([1., 1.], name="v")

        @def_function.function(input_signature=[])
        def f():
            root.v.assign([2., 2.])

        root.f = f

        save.export_meta_graph(
            obj=root,
            filename=file_name,
            options=save_options.SaveOptions(
                experimental_variable_policy=expand_strategy))

    # 3. Read the output file and test behavior.
    meta_graph_def = meta_graph.read_meta_graph_file(file_name)
    object_graph = meta_graph_def.object_graph_def
    graph_def = meta_graph_def.graph_def
    v = next((n.variable
              for n in object_graph.nodes
              if n.HasField("variable") and n.variable.name == "v"), None)
    saved_function = next((f for f in graph_def.library.function
                           if "inference_f_" in f.signature.name), None)
    # Fail with a clear assertion rather than an AttributeError on `v.device`
    # below if the variable node is missing from the object graph.
    self.assertIsNotNone(v)
    self.assertIsNotNone(saved_function)

    if (expand_strategy ==
            save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
        # experimental_save_variable_devices should have been automatically
        # set.
        self.assertIn("CPU:0", v.device)
        components = v.experimental_distributed_variable_components
        self.assertLen(components, 2)
        v0 = next((x for x in components if x.name == "v"), None)
        v1 = next((x for x in components if x.name == "v/replica_1"), None)
        self.assertIsNotNone(v0)
        self.assertIsNotNone(v1)
        self.assertIn("CPU:0", v0.device)
        self.assertIn("CPU:1", v1.device)
        self.assertLen(saved_function.signature.input_arg, 2)
    else:
        self.assertEmpty(v.device)
        self.assertEmpty(v.experimental_distributed_variable_components)
        self.assertLen(saved_function.signature.input_arg, 1)
def test_expand_distributed_variables_not_allowed(self):
    """Checks that `save.save` rejects EXPAND_DISTRIBUTED_VARIABLES."""
    root = tracking.AutoTrackable()
    expand_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
    expected_message = "not implemented in saved_model.save"
    with self.assertRaisesRegex(NotImplementedError, expected_message):
        save.save(
            obj=root,
            export_dir="",
            options=save_options.SaveOptions(
                experimental_variable_policy=expand_policy))
def main(argv):
    """Saves a `ToyModule` to `FLAGS.saved_model_path` without debug info.

    Args:
      argv: Command-line arguments; anything beyond the first entry is
        rejected.

    Raises:
      app.UsageError: If more than one argument is supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    v2_compat.enable_v2_behavior()

    no_debug_info = save_options.SaveOptions(save_debug_info=False)
    save.save(ToyModule(), FLAGS.saved_model_path, options=no_debug_info)
    logging.info('Saved model to: %s', FLAGS.saved_model_path)
def testIgnorePackedVariableInSaveContext(self, distribution):
    """Checks that `_packed_variable` reads as None inside a save context."""
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
        v = variables_lib.Variable(0)
        packed_var = v._packed_variable
        self.assertIsInstance(packed_var, packed.PackedDistributedVariable)

    # Inside the save context the packed variable must not be exposed.
    with save_context.save_context(save_options.SaveOptions()):
        self.assertIsNone(v._packed_variable)
def testRetraceOnSavingFirstTraceOutsideScope(self, distribution):
    """Checks retracing of a function first traced outside a save context.

    The first call under a save context must trigger exactly one new trace;
    a second call under a fresh save context must reuse it.
    """
    with distribution.scope():
        v = variables.Variable(0.)

    # A one-element list lets the traced function bump the counter.
    tracing_count = [0]

    @def_function.function
    def func():
        tracing_count[0] += 1
        return v + 1.

    func()

    traces_before = tracing_count[0]
    with save_context.save_context(save_options.SaveOptions()):
        func()
    self.assertEqual(traces_before + 1, tracing_count[0])

    traces_before = tracing_count[0]
    with save_context.save_context(save_options.SaveOptions()):
        func()
    self.assertEqual(traces_before, tracing_count[0])
def test_accepts_variable_policy(self):
    """Checks the accepted forms of `experimental_variable_policy`.

    Covers the default value, `VariablePolicy` members, their string
    spellings, and rejection of an unknown string.
    """
    policy_enum = save_options.VariablePolicy

    # Default when nothing is passed.
    options = save_options.SaveOptions()
    self.assertEqual(policy_enum.NONE, options.experimental_variable_policy)

    # (value passed in, policy expected on the resulting options)
    accepted = [
        # VariablePolicy instances.
        (policy_enum.SAVE_VARIABLE_DEVICES,
         policy_enum.SAVE_VARIABLE_DEVICES),
        (policy_enum.EXPAND_DISTRIBUTED_VARIABLES,
         policy_enum.EXPAND_DISTRIBUTED_VARIABLES),
        # String conversions.
        ("save_variable_devices", policy_enum.SAVE_VARIABLE_DEVICES),
        ("expand_distributed_variables",
         policy_enum.EXPAND_DISTRIBUTED_VARIABLES),
    ]
    for given, expected in accepted:
        options = save_options.SaveOptions(experimental_variable_policy=given)
        self.assertEqual(expected, options.experimental_variable_policy)

    with self.assertRaisesRegex(ValueError, "Invalid VariablePolicy value"):
        save_options.SaveOptions(
            experimental_variable_policy="not_a_valid_value")
def test_save_debug_info_disabled(self):
    """Checks that `save_debug_info=False` writes no debug info file."""
    root = tracking.AutoTrackable()
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root.f = def_function.function(
        lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
        input_signature=[float_spec])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    no_debug_options = save_options.SaveOptions(save_debug_info=False)
    save.save(root, save_dir, root.f, options=no_debug_options)

    debug_info_file_name = os.path.join(save_dir, "debug",
                                        "saved_model_debug_info.pb")
    self.assertFalse(os.path.exists(debug_info_file_name))
def testCacheWithinSaveContext(self):
    """Checks that a default save context reuses the cached concrete function."""

    @def_function.function
    def func(x):
        return 2 * x

    def trace():
        # Always request the concrete function for the same constant input.
        return func.get_concrete_function(constant_op.constant(2.))

    func_a = trace()
    func_b = trace()
    self.assertIs(func_a, func_b)

    with save_context.save_context(save_options.SaveOptions()):
        func_c = trace()

    self.assertIs(func_a, func_c)
def testCacheWithinSaveContext(self):
    """Checks that save contexts with a variable policy get separate traces.

    Outside a save context repeated requests return the same concrete
    function; under save contexts created with an explicit variable policy
    (EXPAND_DISTRIBUTED_VARIABLES or NONE) a different one is returned.
    """

    @def_function.function
    def func(x):
        return 2 * x

    def trace():
        # Always request the concrete function for the same constant input.
        return func.get_concrete_function(constant_op.constant(2.))

    func_a = trace()
    func_b = trace()
    self.assertIs(func_a, func_b)

    expand_options = save_options.SaveOptions(
        experimental_variable_policy=save_options.VariablePolicy
        .EXPAND_DISTRIBUTED_VARIABLES)
    with save_context.save_context(expand_options):
        func_c = trace()

    none_options = save_options.SaveOptions(
        experimental_variable_policy=save_options.VariablePolicy.NONE)
    with save_context.save_context(none_options):
        func_d = trace()

    self.assertIsNot(func_a, func_c)
    self.assertIsNot(func_a, func_d)
def export_meta_graph(obj, filename, signatures=None, options=None):
    """Exports the MetaGraph proto to a file.

    Args:
      obj: Object forwarded to `_build_meta_graph` as the object to export.
      filename: Path of the file the serialized MetaGraph is written to. Its
        parent directory is passed to `_build_meta_graph` as the export
        directory.
      signatures: Optional signatures forwarded to `_build_meta_graph`.
      options: Optional `save_options.SaveOptions`; a default instance is used
        when omitted or falsy.
    """
    if not options:
        options = save_options.SaveOptions()

    export_dir = os.path.dirname(filename)
    build_result = _build_meta_graph(obj, export_dir, signatures, options)
    meta_graph_def, exported_graph, _, _ = build_result

    serialized = meta_graph_def.SerializeToString(deterministic=True)
    file_io.atomic_write_string_to_file(filename, serialized)

    # Clean reference cycles so repeated export()s don't make work for the
    # garbage collector. Before this point, we need to keep references to
    # captured constants in the saved graph.
    ops.dismantle_graph(exported_graph)
def testGraphDebugInfo(self):
    """Test a SavedModel has debug info captured."""
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)

    sample_input = constant_op.constant(1., shape=[1])
    concrete_fn = root.f.get_concrete_function(sample_input)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    debug_options = save_options.SaveOptions(save_debug_info=True)
    save(root, save_dir, concrete_fn, debug_options)

    # Convert the model; the converted output itself is not inspected, only
    # the debug info recorded on the converter.
    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
    converter.convert()
    self._assertValidDebugInfo(converter._debug_info)
def thread_fn():
    """Exercises the save context from this thread.

    Enters a save context with `save_debug_info=False`, signals
    `entered_context_in_thread`, waits for `continue_thread`, and verifies that
    no save context is active before entering and after leaving.
    """
    missing_context_regex = 'not in a SaveContext'

    def assert_outside_save_context():
        self.assertFalse(save_context.in_save_context())
        with self.assertRaisesRegex(ValueError, missing_context_regex):
            save_context.get_save_options()

    assert_outside_save_context()

    thread_options = save_options.SaveOptions(save_debug_info=False)
    with save_context.save_context(thread_options):
        self.assertTrue(save_context.in_save_context())
        # save_debug_info has a different value in this thread.
        active_options = save_context.get_save_options()
        self.assertFalse(active_options.save_debug_info)
        entered_context_in_thread.set()
        continue_thread.wait()

    assert_outside_save_context()
def test_function_aliases(self):
    """Checks that a function alias is written to the saved MetaGraph.

    The single cached concrete function's name must equal the first key of
    `meta_info_def.function_aliases` in the saved model.
    """
    root = tracking.AutoTrackable()
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root.f = def_function.function(
        lambda x: 2. * x, input_signature=[float_spec])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    alias_map = {
        "my_func": root.f,
    }
    options = save_options.SaveOptions(function_aliases=alias_map)
    save.save(root, save_dir, root.f, options=options)

    cached_functions = list(
        root.f._stateful_fn._function_cache.all_values())
    saved_meta_graph = loader_impl.parse_saved_model(save_dir).meta_graphs[0]
    saved_aliases = saved_meta_graph.meta_info_def.function_aliases

    self.assertLen(cached_functions, 1)
    concrete_name = cached_functions[0].name.decode("utf-8")
    first_alias_key = list(saved_aliases.keys())[0]
    self.assertEqual(concrete_name, first_alias_key)
def main(args):
    """Builds the named module and saves it with debug info enabled.

    Args:
      args: Exactly three entries; the second is the export path and the third
        the module name looked up in `MODULE_CTORS`.

    Returns:
      1 if the argument count is wrong, 2 if the module name has no usable
      constructor, otherwise None after the module is saved.
    """
    if len(args) != 3:
        print("Expected: {export_path} {ModuleName}")
        print("Allowed ModuleNames:", MODULE_CTORS.keys())
        return 1

    export_path, module_name = args[1], args[2]

    module_ctor = MODULE_CTORS.get(module_name)
    if not module_ctor:
        print("Expected ModuleName to be one of:", MODULE_CTORS.keys())
        return 2

    os.makedirs(export_path)

    tf_module = module_ctor()
    debug_options = save_options.SaveOptions(save_debug_info=True)
    saved_model.save(tf_module, export_path, options=debug_options)
def test_multi_thread(self):
    """Checks that the save context is tracked independently per thread.

    The main thread enters a save context with `save_debug_info=True` while a
    worker thread enters its own with `save_debug_info=False`. Neither thread
    may observe the other's context.

    Failures raised in the worker thread are collected and re-raised in the
    main thread so that they fail the test, and both threads are always
    released so a failed assertion cannot leave the other side blocked.
    """
    not_in_context = 'not in a SaveContext'

    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, not_in_context):
        save_context.get_save_options()

    options = save_options.SaveOptions(save_debug_info=True)
    with save_context.save_context(options):
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)

        entered_context_in_thread = threading.Event()
        continue_thread = threading.Event()
        thread_errors = []

        def thread_checks():
            self.assertFalse(save_context.in_save_context())
            with self.assertRaisesRegex(ValueError, not_in_context):
                save_context.get_save_options()

            thread_options = save_options.SaveOptions(save_debug_info=False)
            with save_context.save_context(thread_options):
                self.assertTrue(save_context.in_save_context())
                # save_debug_info has a different value in this thread.
                self.assertFalse(
                    save_context.get_save_options().save_debug_info)
                entered_context_in_thread.set()
                continue_thread.wait()

            self.assertFalse(save_context.in_save_context())
            with self.assertRaisesRegex(ValueError, not_in_context):
                save_context.get_save_options()

        def thread_fn():
            try:
                thread_checks()
            except Exception as err:  # pylint: disable=broad-except
                # Exceptions do not cross thread boundaries on their own;
                # keep them so the main thread can re-raise.
                thread_errors.append(err)
            finally:
                # Unblock the main thread even if a check failed before the
                # event was set inside thread_checks.
                entered_context_in_thread.set()

        t = threading.Thread(target=thread_fn)
        t.start()
        try:
            entered_context_in_thread.wait()
            # Another thread shouldn't affect this thread.
            self.assertTrue(save_context.in_save_context())
            self.assertTrue(save_context.get_save_options().save_debug_info)
        finally:
            # Always release and reap the worker, even when an assertion
            # above failed, so it is not left waiting forever.
            continue_thread.set()
            t.join()

        if thread_errors:
            raise thread_errors[0]

        # Another thread exiting SaveContext shouldn't affect this thread.
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)

    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, not_in_context):
        save_context.get_save_options()
def test_expand_distributed_variables(self, expand_strategy):
    """Checks graph nodes exported for a mirrored variable per policy.

    Under EXPAND_DISTRIBUTED_VARIABLES both replica nodes ("v" and
    "v/replica_1") are expected with device strings; otherwise only "v" is
    expected, without a device.

    Args:
      expand_strategy: Value passed as `experimental_variable_policy`.
    """
    # Guarantee that two logical CPU devices are available.
    context._reset_context()
    physical_cpus = context.context().list_physical_devices("CPU")
    if len(physical_cpus) == 1:
        logical_cpus = [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration(),
        ]
        context.context().set_logical_device_configuration(
            physical_cpus[0], logical_cpus)
    context.ensure_initialized()

    file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
    with mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]).scope():
        root = tracking.AutoTrackable()
        root.v = variables.Variable([1., 1.], name="v")

        @def_function.function(input_signature=[])
        def f():
            root.v.assign([2., 2.])

        root.f = f

        export_options = save_options.SaveOptions(
            experimental_variable_policy=expand_strategy)
        save.export_meta_graph(
            obj=root, filename=file_name, options=export_options)

    graph_def = meta_graph.read_meta_graph_file(file_name).graph_def

    def find_node(name):
        return next((n for n in graph_def.node if n.name == name), None)

    primary = find_node("v")
    replica = find_node("v/replica_1")
    self.assertIsNotNone(primary)

    saved_function = next(
        (fn for fn in graph_def.library.function
         if "inference_f_" in fn.signature.name),
        None)
    self.assertIsNotNone(saved_function)

    expanded = (
        expand_strategy ==
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
    if not expanded:
        self.assertIsNone(replica)
        self.assertEmpty(primary.device)
        # TODO(b/159752793): There should be only one input here.
        self.assertLen(saved_function.signature.input_arg, 2)
    else:
        self.assertIsNotNone(replica)
        # experimental_save_variable_devices should have been automatically
        # set.
        self.assertIn("CPU:0", primary.device)
        self.assertIn("CPU:1", replica.device)
        self.assertLen(saved_function.signature.input_arg, 2)
def test_save_load_io_device(self, model_and_input, distribution):
    """Saves and reloads a trained model with an explicit I/O device.

    Both `SaveOptions` and `LoadOptions` use
    `experimental_io_device='/job:localhost'`; the reloaded model must be
    trainable under the same distribution scope.

    Args:
      model_and_input: Provides `get_model`, `get_data` and `get_batch_size`.
      distribution: Strategy whose scope is used for training and loading.
    """
    saved_dir = os.path.join(self.get_temp_dir(), 'io_device')

    with distribution.scope():
        model = model_and_input.get_model()
        x_train, y_train, _ = model_and_input.get_data()
        batch_size = model_and_input.get_batch_size()
        self._train_model(model, x_train, y_train, batch_size)

    input_spec = tensor_spec.TensorSpec(None)
    call = model.__call__.get_concrete_function(input_spec)

    save_options = save_options_lib.SaveOptions(
        experimental_io_device='/job:localhost')
    saved_model.save(model, saved_dir, signatures=call, options=save_options)

    load_options = load_options_lib.LoadOptions(
        experimental_io_device='/job:localhost')

    # Check that the model can be loaded and training continued without
    # error.
    with distribution.scope():
        loaded_model = saved_model.load(saved_dir, options=load_options)
        self._train_model(loaded_model, x_train, y_train, batch_size)
def _test(f, v):
    """Verifies how `f` is traced under a save context with policy NONE.

    The traced graph must
      - contain no device annotations, and
      - capture at most the primary component's handle of `v`.
    """
    wrapped = def_function.function(lambda: _discard_return(f))
    none_policy = save_options.VariablePolicy.NONE
    options = save_options.SaveOptions(
        experimental_variable_policy=none_policy)

    with save_context.save_context(options):
        graph = wrapped.get_concrete_function().graph

        # The graph should contain no device.
        for op in graph.get_operations():
            self.assertEqual(op.device, "", msg=str(op))

        # The function should only capture the primary variable. Note that it
        # may not have captures, e.g. v.aggregation.
        captured = list(graph.captures)
        self.assertLessEqual(len(captured), 1)
        if graph.captures:
            first_external, _ = captured[0][0], None
            self.assertIs(first_external, v._primary.handle)
def main(args):
    """Builds the named module and exports it in the format its version needs.

    `MODULE_CTORS` maps a module name to a `(constructor, version)` pair.
    Version 2 modules are written with `saved_model.save` and debug info
    enabled; any other version goes through `SavedModelBuilder` with the
    "serve" tag.

    Args:
      args: Exactly three entries; the second is the export path and the third
        the module name looked up in `MODULE_CTORS`.

    Returns:
      1 if the argument count is wrong, 2 if the module name is unknown or has
      no usable constructor, otherwise None after the export completes.
    """
    if len(args) != 3:
        print("Expected: {export_path} {ModuleName}")
        print("Allowed ModuleNames:", MODULE_CTORS.keys())
        return 1
    _, export_path, module_name = args

    # Look the entry up before unpacking: `.get` returns None for an unknown
    # name, and unpacking None would raise TypeError instead of reaching the
    # usage message below.
    entry = MODULE_CTORS.get(module_name)
    if entry is None:
        module_ctor, version = None, None
    else:
        module_ctor, version = entry
    if not module_ctor:
        print("Expected ModuleName to be one of:", MODULE_CTORS.keys())
        return 2

    os.makedirs(export_path)

    tf_module = module_ctor()
    if version == 2:
        options = save_options.SaveOptions(save_debug_info=True)
        saved_model.save(tf_module, export_path, options=options)
    else:
        builder = saved_model.builder.SavedModelBuilder(export_path)
        builder.add_meta_graph_and_variables(tf_module, ["serve"])
        builder.save()
def save(obj, export_dir, signatures=None, options=None):  # pylint: disable=line-too-long
    """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

    Example usage:

    ```python
    class Adder(tf.Module):

        @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
        def add(self, x):
            return x + x + 1.

    to_export = Adder()
    tf.saved_model.save(to_export, '/tmp/adder')
    ```

    The resulting SavedModel is then servable with an input named "x", its
    value having any shape and dtype float32.

    The optional `signatures` argument controls which methods in `obj` will be
    available to programs which consume `SavedModel`s, for example, serving
    APIs. Python functions may be decorated with
    `@tf.function(input_signature=...)` and passed as signatures directly, or
    lazily with a call to `get_concrete_function` on the method decorated with
    `@tf.function`.

    If the `signatures` argument is omitted, `obj` will be searched for
    `@tf.function`-decorated methods. If exactly one `@tf.function` is found,
    that method will be used as the default signature for the SavedModel. This
    behavior is expected to change in the future, when a corresponding
    `tf.saved_model.load` symbol is added. At that point signatures will be
    completely optional, and any `@tf.function` attached to `obj` or its
    dependencies will be exported for use with `load`.

    When invoking a signature in an exported SavedModel, `Tensor` arguments are
    identified by name. These names will come from the Python function's
    argument names by default. They may be overridden by specifying a
    `name=...` argument in the corresponding `tf.TensorSpec` object. Explicit
    naming is required if multiple `Tensor`s are passed through a single
    argument to the Python function.

    The outputs of functions used as `signatures` must either be flat lists, in
    which case outputs will be numbered, or a dictionary mapping string keys to
    `Tensor`, in which case the keys will be used to name outputs.

    Signatures are available in objects returned by `tf.saved_model.load` as a
    `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
    on an object with a custom `.signatures` attribute will raise an exception.

    Since `tf.keras.Model` objects are also Trackable, this function can be
    used to export Keras models. For example, exporting with a signature
    specified:

    ```python
    class Model(tf.keras.Model):

        @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
        def serve(self, serialized):
            ...

    m = Model()
    tf.saved_model.save(m, '/tmp/saved_model/')
    ```

    Exporting from a function without a fixed signature:

    ```python
    class Model(tf.keras.Model):

        @tf.function
        def call(self, x):
            ...

    m = Model()
    tf.saved_model.save(
        m, '/tmp/saved_model/',
        signatures=m.call.get_concrete_function(
            tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
    ```

    `tf.keras.Model` instances constructed from inputs and outputs already have
    a signature and so do not require a `@tf.function` decorator or a
    `signatures` argument. If neither are specified, the model's forward pass
    is exported.

    ```python
    x = input_layer.Input((4,), name="x")
    y = core.Dense(5, name="out")(x)
    model = training.Model(x, y)
    tf.saved_model.save(model, '/tmp/saved_model/')
    # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
    # with shape [None, 5]
    ```

    Variables must be tracked by assigning them to an attribute of a tracked
    object or to an attribute of `obj` directly. TensorFlow objects (e.g.
    layers from `tf.keras.layers`, optimizers from `tf.train`) track their
    variables automatically. This is the same tracking scheme that
    `tf.train.Checkpoint` uses, and an exported `Checkpoint` object may be
    restored as a training checkpoint by pointing `tf.train.Checkpoint.restore`
    to the SavedModel's "variables/" subdirectory. Currently, variables are the
    only stateful objects supported by `tf.saved_model.save`, but others (e.g.
    tables) will be supported in the future.

    `tf.function` does not hard-code device annotations from outside the
    function body, instead using the calling context's device. This means for
    example that exporting a model that runs on a GPU and serving it on a CPU
    will generally work, with some exceptions. `tf.device` annotations inside
    the body of the function will be hard-coded in the exported model; this
    type of annotation is discouraged. Device-specific operations, e.g. with
    "cuDNN" in the name or with device-specific layouts, may cause issues.
    Currently a `DistributionStrategy` is another exception: active
    distribution strategies will cause device placements to be hard-coded in a
    function. Exporting a single-device computation and importing under a
    `DistributionStrategy` is not currently supported, but may be in the
    future.

    SavedModels exported with `tf.saved_model.save` [strip default-valued
    attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
    automatically, which removes one source of incompatibilities when the
    consumer of a SavedModel is running an older TensorFlow version than the
    producer. There are however other sources of incompatibilities which are
    not handled automatically, such as when the exported model contains
    operations which the consumer does not have definitions for.

    A single tf.function can generate many ConcreteFunctions. If a downstream
    tool wants to refer to all concrete functions generated by a single
    tf.function you can use the `function_aliases` argument to store a map from
    the alias name to all concrete function names.
    E.g.

    ```python
    class MyModel:

        @tf.function
        def func(self):
            ...

        @tf.function
        def serve(self):
            ...
            self.func()

    model = MyModel()
    signatures = {
        'serving_default': model.serve.get_concrete_function(),
    }
    options = tf.saved_model.SaveOptions(function_aliases={
        'my_func': model.func,
    })
    tf.saved_model.save(model, export_dir, signatures, options)
    ```

    Args:
      obj: A trackable object to export.
      export_dir: A directory in which to write the SavedModel.
      signatures: Optional, either a `tf.function` with an input signature
        specified or the result of `f.get_concrete_function` on a
        `@tf.function`-decorated function `f`, in which case `f` will be used
        to generate a signature for the SavedModel under the default serving
        signature key. `signatures` may also be a dictionary, in which case it
        maps from signature keys to either `tf.function` instances with input
        signatures or concrete functions. The keys of such a dictionary may be
        arbitrary strings, but will typically be from the
        `tf.saved_model.signature_constants` module.
      options: Optional, `tf.saved_model.SaveOptions` object that specifies
        options for saving. A default `SaveOptions()` is used when omitted.

    Raises:
      ValueError: If `obj` is not trackable.
      FileNotFoundError: If, when executing eagerly, waiting for the pending
        save operations fails with `errors.NotFoundError`.

    @compatibility(eager)
    Not well supported when graph building. From TensorFlow 1.x,
    `tf.compat.v1.enable_eager_execution()` should run first. Calling
    tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
    add new save operations to the default graph each iteration.

    May not be called from within a function body.
    @end_compatibility
    """
    options = options or save_options.SaveOptions()
    # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
    # compatible (no sessions) and share it with this export API rather than
    # making a SavedModel proto and writing it directly.
    saved_model = saved_model_pb2.SavedModel()
    meta_graph_def = saved_model.meta_graphs.add()

    _, exported_graph, object_saver, asset_info = _build_meta_graph(
        obj, export_dir, signatures, options, meta_graph_def)
    saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    # Write the checkpoint, copy assets into the assets directory, and write
    # out the SavedModel proto itself.
    utils_impl.get_or_create_variables_dir(export_dir)
    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    object_saver.save(
        utils_impl.get_variables_path(export_dir), options=ckpt_options)
    builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                                export_dir)

    if context.executing_eagerly():
        try:
            context.async_wait()  # Ensure save operations have completed.
        except errors.NotFoundError as err:
            # Chain explicitly so the traceback presents the NotFoundError as
            # the cause rather than as an unrelated error during handling.
            raise FileNotFoundError(
                str(err) + "\n If trying to save on a different device from the "
                "computational device, consider using setting the "
                "`experimental_io_device` option on tf.saved_model.SaveOptions "
                "to the io_device such as '/job:localhost'.") from err

    # Note that this needs to be the last file operation when saving the
    # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
    # indication that the SavedModel is completely written.
    path = os.path.join(
        compat.as_str(export_dir),
        compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
    file_io.atomic_write_string_to_file(
        path, saved_model.SerializeToString(deterministic=True))

    # Clean reference cycles so repeated export()s don't make work for the
    # garbage collector. Before this point, we need to keep references to
    # captured constants in the saved graph.
    ops.dismantle_graph(exported_graph)
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  Which methods of `obj` are made available to programs that consume
  `SavedModel`s (serving APIs, for example) is controlled by the optional
  `signatures` argument. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  When the `signatures` argument is omitted, `obj` is searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found,
  that method is used as the default signature for the SavedModel. This
  behavior is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  `Tensor` arguments are identified by name when a signature of an exported
  SavedModel is invoked. By default these names come from the Python
  function's argument names. They may be overridden by specifying a
  `name=...` argument in the corresponding `tf.TensorSpec` object. Explicit
  naming is required if multiple `Tensor`s are passed through a single
  argument to the Python function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have
  a signature and so do not require a `@tf.function` decorator or a
  `signatures` argument. If neither are specified, the model's forward pass is
  exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be
  supported in the future.

  `tf.function` does not hard-code device annotations from outside the
  function body, instead using the calling context's device. This means for
  example that exporting a model which runs on a GPU and serving it on a CPU
  will generally work, with some exceptions. `tf.device` annotations inside
  the body of the function will be hard-coded in the exported model; this type
  of annotation is discouraged. Device-specific operations, e.g. with "cuDNN"
  in the name or with device-specific layouts, may cause issues. Currently a
  `DistributionStrategy` is another exception: active distribution strategies
  will cause device placements to be hard-coded in a function. Exporting a
  single-device computation and importing under a `DistributionStrategy` is
  not currently supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the
  consumer of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.

  Raises:
    ValueError: If `obj` is not trackable.
    AssertionError: If called while tracing a `tf.function`.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # Guard clauses: reject unsupported calling contexts and inputs up front.
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  if not options:
    options = save_options.SaveOptions()

  graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    # No explicit signatures: look for an exportable function on `obj`.
    signatures = signature_serialization.find_function_to_export(graph_view)
  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(graph_view)

  # Attach the signature map to the root object under the reserved attribute
  # name.
  root_signature_map = signature_serialization.create_signature_map(signatures)
  graph_view.add_object(
      parent_node=graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=root_signature_map)

  # _SaveableView provides a frozen listing of properties and functions. It is
  # deliberately constructed twice: building the view the first time can have
  # the side effect of creating variables, so only the second view is used.
  _ = _SaveableView(graph_view)
  frozen_view = _SaveableView(graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model_proto = saved_model_pb2.SavedModel()
  meta_graph = saved_model_proto.meta_graphs.add()
  checkpoint_saver = util.TrackableSaver(graph_view)
  assets, graph_to_export = _fill_meta_graph_def(
      meta_graph, frozen_view, signatures, options.namespace_whitelist)
  saved_model_proto.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)

  # Everything above only generated protocol buffers; the I/O starts here:
  # write the checkpoint, copy assets into the assets directory, and finally
  # write out the SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  checkpoint_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(assets.asset_filename_map,
                                              export_dir)
  saved_model_path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  meta_graph.object_graph_def.CopyFrom(
      _serialize_object_graph(frozen_view, assets.asset_index))

  # Save debug info, if requested.
  if options.save_debug_info:
    debug_info = _export_debug_info(graph_to_export)
    debug_info_path = os.path.join(
        utils_impl.get_or_create_debug_dir(export_dir),
        constants.DEBUG_INFO_FILENAME_PB)
    file_io.atomic_write_string_to_file(
        debug_info_path, debug_info.SerializeToString(deterministic=True))

  # This must remain the last file operation of the save: users rely on
  # checking saved_model_dir/saved_model.pb as an indication that the
  # SavedModel is completely written.
  file_io.atomic_write_string_to_file(
      saved_model_path, saved_model_proto.SerializeToString(deterministic=True))

  # Clean reference cycles so repeated export()s don't make work for the
  # garbage collector. Before this point we need to keep references to
  # captured constants in the saved graph.
  ops.dismantle_graph(graph_to_export)
def test_accepts_io_device(self):
  """Checks that `experimental_io_device` defaults to None and is stored."""
  # Without an explicit argument the option is unset.
  default_options = save_options.SaveOptions()
  self.assertIsNone(default_options.experimental_io_device)

  # An explicitly supplied device string is exposed unchanged.
  configured_options = save_options.SaveOptions(
      experimental_io_device="/job:localhost")
  self.assertEqual("/job:localhost", configured_options.experimental_io_device)
def test_enter_multiple(self):
  """Checks that entering a save context while one is active raises."""
  options = save_options.SaveOptions()
  with self.assertRaisesRegex(ValueError, 'already in a SaveContext'):
    # The contexts are entered left to right, exactly as if nested; entering
    # the second one while the first is still active is expected to raise.
    with save_context.save_context(options), save_context.save_context(
        options):
      pass
def __init__(self,
             filepath,
             monitor='val_loss',
             verbose=0,
             save_best_only=False,
             save_weights_only=False,
             mode='auto',
             save_freq='epoch',
             options=None,
             **kwargs):
  """Stores the checkpointing configuration and sets up metric comparison.

  Args:
    filepath: Destination of the checkpoint; converted with
      `path_to_string` before being stored on `self.filepath`.
    monitor: Name of the quantity to monitor. In `'auto'` mode the name also
      selects the comparison direction (see `mode`).
    verbose: Verbosity setting, stored as-is.
    save_best_only: Stored as-is on `self.save_best_only`.
    save_weights_only: Stored as-is; also determines which type `options`
      must have.
    mode: One of `'auto'`, `'min'` or `'max'`. Any other value logs a warning
      and falls back to `'auto'`. With `'min'` the comparison is `np.less`
      starting from `+inf`; with `'max'` it is `np.greater` starting from
      `-inf`. With `'auto'`, `np.greater` is used when `monitor` contains
      `'acc'` or starts with `'fmeasure'`, otherwise `np.less`.
    save_freq: Either the string `'epoch'` or an integer.
    options: `None`, or a `tf.train.CheckpointOptions` when
      `save_weights_only` is true, or a `tf.saved_model.SaveOptions` when it
      is false. `None` is replaced by a default-constructed instance of the
      matching type.
    **kwargs: Only the deprecated keys `load_weights_on_restart` and `period`
      are read; each logs a deprecation warning when present.

  Raises:
    TypeError: If `options` is not `None` and is not of the type required by
      `save_weights_only`.
    ValueError: If `save_freq` is neither `'epoch'` nor an integer.
  """
  super(ModelCheckpoint, self).__init__()
  self.filepaths = []
  self._supports_tf_logs = True
  self.monitor = monitor
  self.verbose = verbose
  self.filepath = tf.python.keras.utils.io_utils.path_to_string(filepath)
  self.save_best_only = save_best_only
  self.save_weights_only = save_weights_only
  self.save_freq = save_freq
  self.epochs_since_last_save = 0
  self._batches_seen_since_last_saving = 0
  self._last_batch_seen = 0

  # The accepted `options` type depends on what is being written: weights-only
  # saves take checkpoint options, full-model saves take SavedModel options.
  if save_weights_only:
    if options is None or isinstance(
        options, checkpoint_options_lib.CheckpointOptions):
      self._options = options or checkpoint_options_lib.CheckpointOptions()
    else:
      raise TypeError(
          'If save_weights_only is True, then `options` must be '
          'either None or a tf.train.CheckpointOptions')
  else:
    if options is None or isinstance(options, save_options_lib.SaveOptions):
      self._options = options or save_options_lib.SaveOptions()
    else:
      raise TypeError(
          'If save_weights_only is False, then `options` must be '
          'either None or a tf.saved_model.SaveOptions')

  # Deprecated field `load_weights_on_restart` is for loading the checkpoint
  # file from `filepath` at the start of `model.fit()`
  # TODO(rchao): Remove the arg during next breaking release.
  if 'load_weights_on_restart' in kwargs:
    self.load_weights_on_restart = kwargs['load_weights_on_restart']
    logging.warning(
        '`load_weights_on_restart` argument is deprecated. '
        'Please use `model.load_weights()` for loading weights '
        'before the start of `model.fit()`.')
  else:
    self.load_weights_on_restart = False

  # Deprecated field `period` is for the number of epochs between which
  # the model is saved.
  if 'period' in kwargs:
    self.period = kwargs['period']
    logging.warning(
        '`period` argument is deprecated. Please use `save_freq` '
        'to specify the frequency in number of batches seen.')
  else:
    self.period = 1

  if mode not in ['auto', 'min', 'max']:
    logging.warning(
        'ModelCheckpoint mode %s is unknown, '
        'fallback to auto mode.', mode)
    mode = 'auto'

  # `np.inf` is used rather than the `np.Inf` alias, which NumPy 2.0 removed;
  # both name the same value.
  if mode == 'min':
    self.monitor_op = np.less
    self.best = np.inf
  elif mode == 'max':
    self.monitor_op = np.greater
    self.best = -np.inf
  else:
    # Auto mode: accuracy-like metrics are maximized, everything else is
    # minimized.
    if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
      self.monitor_op = np.greater
      self.best = -np.inf
    else:
      self.monitor_op = np.less
      self.best = np.inf

  if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
    raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))

  # Only the chief worker writes model checkpoints, but all workers
  # restore checkpoint at on_train_begin().
  self._chief_worker_only = False