def test_inititialize_with_data_structures(self, enable_async_ckpt):
  """Round-trips list- and dict-valued Checkpoint attributes.

  Saves a Checkpoint whose `a` is a list of variables and whose `b` is a dict
  of variables, then restores it into a second Checkpoint of the same shape
  and checks the values that land there.
  """
  if enable_async_ckpt and not context.executing_eagerly():
    self.skipTest(
        "Skipping this test as async checkpoint does not support graph mode.")

  def build(list_values, dict_items):
    # `a` holds a list of variables and `b` a dict of variables; variables are
    # created in argument order (list first, then dict entries).
    return trackable_utils.Checkpoint(
        a=[variables_lib.Variable(value) for value in list_values],
        b={key: variables_lib.Variable(value) for key, value in dict_items})

  source = build((0., 1.), (("a", 2.), ("b", 3.)))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  save_options = checkpoint_options.CheckpointOptions(
      experimental_enable_async_checkpoint=enable_async_ckpt)
  save_path = source.save(file_prefix=prefix, options=save_options)

  target = build((4., 5.), (("a", 6.), ("b", 7.)))
  if enable_async_ckpt:
    # With async saving, the file must be fully written before another
    # Checkpoint instance reads it; restoring on the saving Checkpoint is
    # used here to join its async writer thread.
    source.restore(save_path)
  target.restore(save_path)

  self.assertAllClose(self.evaluate(target.a), [0, 1])
  self.assertAllClose(self.evaluate(target.b), {"a": 2, "b": 3})
def testCheckpointSaveRestoreIoDevice(self, distribution):
  """Saves and restores via an explicit io_device under a distribution."""

  def make_variable():
    # Scalar variable with a random initial value, created inside the
    # distribution strategy's scope.
    with distribution.scope():
      return variables_lib.Variable(random_ops.random_normal([]))

  io_options = checkpoint_options.CheckpointOptions(
      experimental_io_device="/job:localhost")

  def write_random_checkpoint():
    # Saves freshly created random weights and returns the checkpoint path.
    saver = trackable_utils.Checkpoint(v=make_variable())
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      return saver.save(prefix, options=io_options)

  save_path = write_random_checkpoint()
  restorer = trackable_utils.Checkpoint(v=make_variable())
  # Restoring inside distribution.scope() must complete without error.
  with self.test_session():
    with distribution.scope():
      restorer.restore(save_path, options=io_options)
def save(self, file_prefix, options=None):
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Specs whose tensor is `None` are skipped: their SaveableObject is recorded
  in the object graph, but no value is written to the checkpoint file.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    options: Optional `CheckpointOptions` object. Its `experimental_io_device`
      hosts the write op; "cpu:0" is used when it is unset.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  options = options or checkpoint_options.CheckpointOptions()
  tensor_names, tensors, tensor_slices = [], [], []
  for saveable in self._saveable_objects:
    for spec in saveable.specs:
      # Read `spec.tensor` once and reuse the value for both the None check
      # and the saved tensor.
      value = spec.tensor
      if value is None:
        continue
      tensor_names.append(spec.name)
      tensors.append(value)
      tensor_slices.append(spec.slice_spec)

  with ops.device(options.experimental_io_device or "cpu:0"):
    return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  All specs are read with a single `restore_v2` call. The flat result is then
  regrouped per SaveableObject and passed to each object's `restore`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object. Its `experimental_io_device`
      hosts the read op; "cpu:0" is used when it is unset.

  Returns:
    A dictionary mapping from SaveableObject names to restore operations.
  """
  options = options or checkpoint_options.CheckpointOptions()
  nested_names = []
  flat_specs = []
  for saveable in self._saveable_objects:
    names_for_saveable = []
    for spec in saveable.specs:
      names_for_saveable.append(spec.name)
      flat_specs.append((spec.name, spec.slice_spec, spec.dtype))
    nested_names.append(names_for_saveable)

  names, slices, dtypes = zip(*flat_specs)
  with ops.device(options.experimental_io_device or "cpu:0"):
    flat_tensors = io_ops.restore_v2(file_prefix, names, slices, dtypes)

  # Regroup the flat tensor list so each SaveableObject receives only the
  # tensors for its own specs, in spec order.
  grouped_tensors = nest.pack_sequence_as(nested_names, flat_tensors)
  return {
      saveable.name: saveable.restore(tensors, restored_shapes=None)
      for saveable, tensors in zip(self._saveable_objects, grouped_tensors)
  }
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation.

  Args:
    export_dir: SavedModel directory to parse.
    tags: Optional tag or collection of tags. For an object-graph SavedModel
      with exactly one MetaGraph they must equal that MetaGraph's tags;
      otherwise they are forwarded to `load_v1_in_v2.load`.
    options: Optional `load_options.LoadOptions`. Its `experimental_io_device`
      is passed to the loader as a `CheckpointOptions`.
    loader_cls: Class constructed to load an object-graph SavedModel.
    filters: Optional node filters forwarded to `loader_cls`. When truthy, the
      result maps each filter entry to `loader.get(entry)`.

  Returns:
    `{"root": root}` when `filters` is falsy; otherwise a dict mapping each
    entry of `filters` to the corresponding loaded object.

  Raises:
    ValueError: if `tags` does not match the single MetaGraph's tags, or if
      `filters` are given for a SavedModel without an object graph.
    FileNotFoundError: if constructing the loader raises
      `errors.NotFoundError`; the original error is chained as the cause.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # tensor_content fields contain raw bytes in little-endian format, which
    # must be byte-swapped before use on big-endian systems.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, filters)
      except errors.NotFoundError as err:
        # Chain the original error so its traceback is kept as the cause.
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider using setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'."
        ) from err
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow V1 or Estimator (any "
                       "version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
def setUp(self):
  """Splits the first CPU into three logical devices and builds options."""
  super(SaverTest, self).setUp()
  first_cpu = config.list_physical_devices("CPU")[0]
  # Three virtual CPUs carved out of the first physical CPU.
  virtual_cpus = [context.LogicalDeviceConfiguration() for _ in range(3)]
  config.set_logical_device_configuration(first_cpu, virtual_cpus)
  self.local_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=LOCALHOST)
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    A dictionary mapping from SaveableObject names to restore operations.
  """
  options = options or checkpoint_options.CheckpointOptions()

  def run_device_restores():
    # Visit devices sorted by name so op construction does not depend on
    # dictionary iteration order.
    ops_by_name = {}
    for device, saver in sorted(self._single_device_savers.items()):
      with ops.device(device):
        ops_by_name.update(saver.restore(file_prefix, options))
    return ops_by_name

  # Wrapping in a tf.function forces a re-trace on every restore, so it is
  # only done when needed: eager execution with more than one single-device
  # saver. The re-trace picks up the current options (e.g.
  # experimental_io_device).
  multi_device_eager = (
      context.executing_eagerly() and len(self._single_device_savers) > 1)
  if not multi_device_eager:
    restore_ops = run_device_restores()
  else:
    # First key in the savers dict (insertion order, not sorted order).
    first_device = next(iter(self._single_device_savers))

    @def_function.function(jit_compile=False)
    def tf_function_restore():
      inner_ops = run_device_restores()
      # A tf.function must return tensors, so each restore op is surfaced as
      # an identity of `file_prefix` carrying a control dependency on it.
      results = {}
      with ops.device(saveable_object_util.set_cpu0(first_device)):
        for name, op in inner_ops.items():
          with ops.control_dependencies([op]):
            results[name] = array_ops.identity(file_prefix)
      return results

    restore_ops = tf_function_restore()

  for callback in self._after_restore_callbacks:
    callback()
  return restore_ops
def testAssertConsumedNoCheckpoint(self, enable_async_ckpt):
  """A restore status can be checked after its Checkpoint is deleted."""
  if enable_async_ckpt and not context.executing_eagerly():
    self.skipTest(
        "Skipping this test as async checkpoint does not support graph mode.")
  v = variable_scope.get_variable(name="v", initializer=0.)
  self.evaluate(v.initializer)
  ckpt = trackable_utils.Checkpoint(v=v)
  self.evaluate(trackable_utils.gather_initializers(ckpt))

  save_path = ckpt.save(
      file_prefix=os.path.join(self.get_temp_dir(), "ckpt"),
      options=checkpoint_options.CheckpointOptions(
          experimental_enable_async_checkpoint=enable_async_ckpt))
  status = ckpt.restore(save_path=save_path)
  # Drop the Checkpoint object before asserting on the status it returned.
  del ckpt
  status.assert_consumed()
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader):
  """Loader implementation.

  Args:
    export_dir: SavedModel directory to parse.
    tags: Optional tag or collection of tags. For an object-graph SavedModel
      with exactly one MetaGraph they must equal that MetaGraph's tags;
      otherwise they are forwarded to `load_v1_in_v2.load`.
    options: Optional `load_options.LoadOptions`. Its `experimental_io_device`
      is passed to the loader as a `CheckpointOptions`.
    loader_cls: Class constructed to load an object-graph SavedModel.

  Returns:
    The root object of the loaded SavedModel.

  Raises:
    ValueError: if `tags` does not match the single MetaGraph's tags.
    FileNotFoundError: if constructing the loader raises
      `errors.NotFoundError`; the original error is chained as the cause.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options)
      except errors.NotFoundError as err:
        # Chain the original error so its traceback is kept as the cause.
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider using setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'."
        ) from err
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info
  return root
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    When not run eagerly or when saving on a single device, returns a
    dictionary mapping from SaveableObject names to restore operations;
    otherwise, returns an empty dict.
  """
  options = options or checkpoint_options.CheckpointOptions()

  def restore_all():
    # Single-device savers run first, in sorted device order so op
    # construction does not depend on dictionary iteration order.
    ops_by_name = {}
    for device, saver in sorted(self._single_device_savers.items()):
      with ops.device(device):
        ops_by_name.update(saver.restore(file_prefix, options))
    # Registered savers run afterwards; their return values are not
    # collected.
    for _, registered_restore_fn in self._registered_savers.values():
      registered_restore_fn(file_prefix)
    return ops_by_name

  # Wrapping in a tf.function forces a re-trace on every restore, so it is
  # only done for eager execution with more than one single-device saver.
  # The re-trace picks up the current options (e.g. experimental_io_device).
  if not (context.executing_eagerly() and
          len(self._single_device_savers) > 1):
    restore_ops = restore_all()
  else:
    @def_function.function(jit_compile=False)
    def tf_function_restore():
      restore_all()
      return {}

    restore_ops = tf_function_restore()

  for callback in self._after_restore_callbacks:
    callback()
  return restore_ops
def testCustomNumbering(self, enable_async_ckpt):
  """Paths returned by `write()` end with the caller-supplied step suffix."""
  if enable_async_ckpt and not context.executing_eagerly():
    self.skipTest(
        "Skipping this test as async checkpoint does not support graph mode.")
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  step = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
  checkpoint = trackable_utils.Checkpoint(step=step)
  write_options = checkpoint_options.CheckpointOptions(
      experimental_enable_async_checkpoint=enable_async_ckpt)
  self.evaluate(step.initializer)

  # `step` advances by 2 after each write, so the expected steps are
  # 0, 2, 4, 6, 8.
  for expected_step in range(0, 10, 2):
    path = checkpoint.write("%s-%d" % (prefix, self.evaluate(step)),
                            options=write_options)
    expected_suffix = "-%d" % (expected_step,)
    if not path.endswith(expected_suffix):
      self.fail("%s should have suffix %s" % (path, expected_suffix))
    self.evaluate(step.assign_add(2))
def testPassingCheckpointOptions(self):
  """Options passed to save/restore place the I/O ops on `io_device`."""
  localhost = "/job:localhost/device:CPU:0"
  io_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=localhost)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  v = variable_scope.get_variable(name="v", initializer=0.)
  self.evaluate(v.initializer)
  ckpt = trackable_utils.Checkpoint(v=v)
  self.evaluate(trackable_utils.gather_initializers(ckpt))
  save_path = ckpt.save(file_prefix=prefix, options=io_options)
  status = ckpt.restore(save_path=save_path, options=io_options)
  del ckpt
  status.assert_consumed()

  if context.executing_eagerly():
    return
  # Graph mode only: every SaveV2/RestoreV2 op must be placed on localhost.
  checkpoint_io_ops = [
      op for op in ops.get_default_graph().get_operations()
      if op.type in ("SaveV2", "RestoreV2")
  ]
  for op in checkpoint_io_ops:
    self.assertEqual(localhost, op.device)
def save(self, file_prefix, options=None):
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    options: Optional `CheckpointOptions` object. Its `experimental_io_device`
      hosts the write op; "cpu:0" is used when it is unset.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  options = options or checkpoint_options.CheckpointOptions()
  # Flatten every spec of every SaveableObject, preserving order.
  all_specs = [
      spec for saveable in self._saveable_objects for spec in saveable.specs
  ]
  tensor_names = [spec.name for spec in all_specs]
  tensors = [spec.tensor for spec in all_specs]
  tensor_slices = [spec.slice_spec for spec in all_specs]

  with ops.device(options.experimental_io_device or "cpu:0"):
    return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Each single-device saver restores under its own device scope; the
  after-restore callbacks run once all savers have been invoked.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    A dictionary mapping from SaveableObject names to restore operations.
  """
  options = options or checkpoint_options.CheckpointOptions()
  # Sorting by device name keeps op construction independent of dictionary
  # iteration order.
  ordered_savers = sorted(self._single_device_savers.items())
  restore_ops = {}
  for device, saver in ordered_savers:
    with ops.device(device):
      restore_ops.update(saver.restore(file_prefix, options))
  for after_restore in self._after_restore_callbacks:
    after_restore()
  return restore_ops
def testMoreComplexSaveableReturned(self, enable_async_ckpt):
  """Restoring `_OwnsMirroredVariables` sets both of its variables.

  After each restore, both `non_dep_variable` and `mirrored` must hold the
  value `non_dep_variable` had when the checkpoint was saved.
  """
  if enable_async_ckpt and not context.executing_eagerly():
    self.skipTest(
        "Skipping this test as async checkpoint does not support graph mode.")
  v = _OwnsMirroredVariables()
  checkpoint = trackable_utils.Checkpoint(v=v)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  save_options = checkpoint_options.CheckpointOptions(
      experimental_enable_async_checkpoint=enable_async_ckpt)

  def restore_and_expect(path, expected):
    # Restores `path` and checks both variables against `expected`.
    checkpoint.restore(path).assert_consumed().initialize_or_restore()
    self.assertEqual(expected, self.evaluate(v.non_dep_variable))
    self.assertEqual(expected, self.evaluate(v.mirrored))

  # First cycle: overwrite both variables after saving.
  self.evaluate(v.non_dep_variable.assign(42.))
  first_path = checkpoint.save(file_prefix=prefix, options=save_options)
  self.evaluate(v.non_dep_variable.assign(43.))
  self.evaluate(v.mirrored.assign(44.))
  restore_and_expect(first_path, 42.)

  # Second cycle: overwrite only the non-dependency variable after saving.
  self.evaluate(v.non_dep_variable.assign(44.))
  second_path = checkpoint.save(file_prefix=prefix, options=save_options)
  self.evaluate(v.non_dep_variable.assign(45.))
  restore_and_expect(second_path, 44.)
def save(self, file_prefix, options=None):
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Each single-device saver writes one shard under a temporary prefix, and the
  shards are then merged into `file_prefix` with `merge_v2_checkpoints`.
  The before-save callbacks run first.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    options: Optional `CheckpointOptions` object. Its `experimental_io_device`,
      when set, hosts the merge op.

  Returns:
    An `Operation`, or None when executing eagerly. The eager multi-device
    path (tf.function branch below) always returns None.
  """
  options = options or checkpoint_options.CheckpointOptions()
  for callback in self._before_save_callbacks:
    callback()

  # IMPLEMENTATION DETAILS: most clients should skip.
  #
  # Suffix for any well-formed "checkpoint_prefix", when sharded.
  # Transformations:
  # * Users pass in "save_path" in save() and restore().  Say "myckpt".
  # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
  #
  # Example:
  #   During runtime, a temporary directory is first created, which contains
  #   files
  #
  #     <train dir>/myckpt_temp/
  #        part-?????-of-?????{.index, .data-00000-of-00001}
  #
  #   Before .save() finishes, they will be (hopefully, atomically) renamed to
  #
  #     <train dir>/
  #        myckpt{.index, .data-?????-of-?????}
  #
  #   Filesystems with eventual consistency (such as S3), don't need a
  #   temporary location. Using a temporary directory in those cases might
  #   cause situations where files are not available during copy.
  #
  # Users only need to interact with the user-specified prefix, which is
  # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
  # prefix directly, instead of any physical pathname.  (On failure and
  # subsequent restore, an outdated and orphaned temporary directory can be
  # safely removed.)
  with ops.device("CPU"):
    # "s3://" prefixes get a ".part" suffix; all others get a unique
    # "_temp_<uuid>/part" temporary directory.
    sharded_suffix = array_ops.where(
        string_ops.regex_full_match(file_prefix, "^s3://.*"),
        constant_op.constant(".part"),
        constant_op.constant("_temp_%s/part" % uuid.uuid4().hex))
    tmp_checkpoint_prefix = string_ops.string_join(
        [file_prefix, sharded_suffix])

  def save_fn():
    # Writes one shard per single-device saver, in sorted device order, then
    # merges them into `file_prefix`.
    num_shards = len(self._single_device_savers)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saver) in enumerate(
        sorted(self._single_device_savers.items())):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                        num_shards_tensor)
      sharded_prefixes.append(shard_prefix)
      with ops.device(device):
        # _SingleDeviceSaver will use the CPU device when necessary, but
        # initial read operations should be placed on the SaveableObject's
        # device.
        sharded_saves.append(saver.save(shard_prefix, options))

    with ops.control_dependencies(sharded_saves):
      # Merge on the io_device if specified, otherwise co-locates the merge op
      # with the last device used.
      merge_device = (
          options.experimental_io_device or
          saveable_object_util.set_cpu0(last_device))
      with ops.device(merge_device):
        # V2 format write path consists of a metadata merge step.  Once
        # merged, attempts to delete the temporary directory,
        # "<user-fed prefix>_temp".
        return gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, file_prefix, delete_old_dirs=True)

  # Since this will cause a function re-trace on each save, limit this to the
  # cases where it is needed: eager and when there are multiple tasks/single
  # device savers. Note that the retrace is needed to ensure we pickup the
  # latest values of options like experimental_io_device.
  if context.executing_eagerly() and len(self._single_device_savers) > 1:
    # Run the multi-device save inside a tf.function. Its result is
    # discarded, so this branch returns None.
    @def_function.function(experimental_compile=False)
    def tf_function_save():
      save_fn()
    tf_function_save()
  else:
    return save_fn()
def __init__(self,
             filepath,
             monitor='val_loss',
             verbose=0,
             save_best_only=False,
             save_weights_only=False,
             mode='auto',
             save_freq='epoch',
             options=None,
             **kwargs):
  """Configures the ModelCheckpoint callback.

  Args:
    filepath: Checkpoint path; normalized with `io_utils.path_to_string`.
    monitor: Name of the metric used to compare checkpoints. Also used to
      infer the comparison direction when `mode` is 'auto'.
    verbose: Verbosity level, stored as-is.
    save_best_only: Stored as-is.
    save_weights_only: When True, `options` must be None or a
      `tf.train.CheckpointOptions`; when False, None or a
      `tf.saved_model.SaveOptions`.
    mode: One of 'auto', 'min', 'max'. Unknown values log a warning and fall
      back to 'auto'. In 'auto' mode, monitors containing 'acc' or starting
      with 'fmeasure' are maximized; all others are minimized.
    save_freq: 'epoch' or an int.
    options: Optional save options; see `save_weights_only` for the accepted
      type. Defaults to a fresh instance of that type.
    **kwargs: Deprecated `load_weights_on_restart` and `period` arguments.

  Raises:
    TypeError: if `options` is not of the type implied by
      `save_weights_only`.
    ValueError: if `save_freq` is neither 'epoch' nor an int.
  """
  super(ModelCheckpoint, self).__init__()
  self.filepaths = []
  self._supports_tf_logs = True
  self.monitor = monitor
  self.verbose = verbose
  self.filepath = tf.python.keras.utils.io_utils.path_to_string(filepath)
  self.save_best_only = save_best_only
  self.save_weights_only = save_weights_only
  self.save_freq = save_freq
  self.epochs_since_last_save = 0
  self._batches_seen_since_last_saving = 0
  self._last_batch_seen = 0

  if save_weights_only:
    if options is None or isinstance(
        options, checkpoint_options_lib.CheckpointOptions):
      self._options = options or checkpoint_options_lib.CheckpointOptions()
    else:
      # Note the trailing space: adjacent literals are concatenated as-is.
      raise TypeError(
          'If save_weights_only is True, then `options` must be '
          'either None or a tf.train.CheckpointOptions')
  else:
    if options is None or isinstance(options, save_options_lib.SaveOptions):
      self._options = options or save_options_lib.SaveOptions()
    else:
      raise TypeError(
          'If save_weights_only is False, then `options` must be '
          'either None or a tf.saved_model.SaveOptions')

  # Deprecated field `load_weights_on_restart` is for loading the checkpoint
  # file from `filepath` at the start of `model.fit()`
  # TODO(rchao): Remove the arg during next breaking release.
  if 'load_weights_on_restart' in kwargs:
    self.load_weights_on_restart = kwargs['load_weights_on_restart']
    logging.warning(
        '`load_weights_on_restart` argument is deprecated. '
        'Please use `model.load_weights()` for loading weights '
        'before the start of `model.fit()`.')
  else:
    self.load_weights_on_restart = False

  # Deprecated field `period` is for the number of epochs between which
  # the model is saved.
  if 'period' in kwargs:
    self.period = kwargs['period']
    logging.warning(
        '`period` argument is deprecated. Please use `save_freq` '
        'to specify the frequency in number of batches seen.')
  else:
    self.period = 1

  if mode not in ['auto', 'min', 'max']:
    logging.warning(
        'ModelCheckpoint mode %s is unknown, '
        'fallback to auto mode.', mode)
    mode = 'auto'

  # `np.inf` is used rather than the `np.Inf` alias, which NumPy 2.0 removed.
  if mode == 'min':
    self.monitor_op = np.less
    self.best = np.inf
  elif mode == 'max':
    self.monitor_op = np.greater
    self.best = -np.inf
  else:
    if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
      self.monitor_op = np.greater
      self.best = -np.inf
    else:
      self.monitor_op = np.less
      self.best = np.inf

  if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
    raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))

  # Only the chief worker writes model checkpoints, but all workers
  # restore checkpoint at on_train_begin().
  self._chief_worker_only = False
def restore(self, save_path, options=None):
  """Restore a training checkpoint with host mesh placement.

  Args:
    save_path: Checkpoint path prefix to read, or None.
    options: Optional `checkpoint_options.CheckpointOptions`.

  Returns:
    `util.InitializationOnlyStatus` when `save_path` is None;
    `util.NameBasedSaverStatus` when the checkpoint has no object-graph proto
    (name-based compatibility mode); otherwise `util.CheckpointLoadStatus`.
  """
  options = options or checkpoint_options.CheckpointOptions()
  if save_path is None:
    return util.InitializationOnlyStatus(self._graph_view, ops.uid())
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  graph_building = not context.executing_eagerly()
  if graph_building:
    dtype_map = None
  else:
    dtype_map = reader.get_variable_to_dtype_map()
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    # The object graph proto does not exist in this checkpoint. Try the
    # name-based compatibility mode.
    restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
        save_path=save_path,
        dtype_map=dtype_map)
    if not graph_building:
      for existing_trackable in self._graph_view.list_objects():
        # pylint: disable=protected-access
        existing_trackable._maybe_initialize_trackable()
        existing_trackable._name_based_restores.add(restore_coordinator)
        existing_trackable._name_based_attribute_restore(restore_coordinator)
        # pylint: enable=protected-access
    return util.NameBasedSaverStatus(
        restore_coordinator, graph_view=self._graph_view)

  if graph_building:
    if self._file_prefix_placeholder is None:
      # DTensor change: provide a hint for mesh broadcasting to put the input
      # onto the host mesh.
      self._file_prefix_placeholder = api.pack(
          [constant_op.constant("model")] * self._mesh.num_local_devices(),
          layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_tensor = self._file_prefix_placeholder
    file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
  else:
    # DTensor change: provide a hint for mesh broadcasting to put the input
    # onto the host mesh.
    file_prefix_tensor = api.pack(
        [constant_op.constant(save_path)] * self._mesh.num_local_devices(),
        layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_feed_dict = None
  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  # DTensor Change: Hook the proper DSaver in restore.
  checkpoint = _DCheckpointRestoreCoordinator(
      mesh=self._mesh,
      object_graph_proto=object_graph_proto,
      save_path=save_path,
      save_path_tensor=file_prefix_tensor,
      reader=reader,
      restore_op_cache=self._restore_op_cache,
      graph_view=self._graph_view,
      options=options,
      saveables_cache=self._saveables_cache)
  base.CheckpointPosition(
      checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

  # Attached dependencies are not attached to the root, so should be restored
  # separately.
  if self._graph_view.attached_dependencies:
    for ref in self._graph_view.attached_dependencies:
      if ref.name == "root":
        # Root dependency is automatically added to attached dependencies --
        # this can be ignored since it maps back to the root object.
        continue
      proto_id = None
      # Find proto ID of attached dependency (if it is in the proto).
      for proto_ref in object_graph_proto.nodes[0].children:
        if proto_ref.local_name == ref.name:
          proto_id = proto_ref.node_id
          break

      if proto_id is None:
        # The attached dependency is not a child of the checkpoint's root
        # node, so there is nothing to restore for it. Skipping avoids
        # building a CheckpointPosition with proto_id=None.
        continue

      if proto_id in checkpoint.object_by_proto_id:
        # Object has already been restored. This can happen when there's an
        # indirect connection from the attached object to the root.
        continue

      base.CheckpointPosition(
          checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

  load_status = util.CheckpointLoadStatus(
      checkpoint,
      graph_view=self._graph_view,
      feed_dict=file_prefix_feed_dict)
  return load_status
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In Tensorflow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the name of the attributes connecting the
  objects.

  *Example 1*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      ['root.child_layer', 'root.child_layer.v'])
  loaded['root.child_layer'].v.numpy()
  # -> 5.0
  loaded['root.child_layer'].v is loaded['root.child_layer.v']
  # -> True
  ```

  *Example 2*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  # Create a variable
  new_variable = tf.Variable(0.)
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      {'root.child_layer': None, 'root.child_layer.v': new_variable})
  loaded['root.child_layer'].v.numpy()
  # -> 5.0
  new_variable.numpy()
  # -> 5.0
  ```

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  ```
  model = tf.Module()
  model.layer_1 = tf.Module()
  model.layer_1.v = tf.Variable(5.)
  model.layer_2 = tf.Module()
  model.layer_2.v = tf.Variable(7.)
  tf.saved_model.save(model, '/tmp/model')
  # Load with no strategy
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      ['root.layer_1'])
  loaded['root.layer_1'].v
  # -> <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    loaded2 = tf.__internal__.saved_model.load_partial(
        '/tmp/model',
        ['root.layer_2'])
  loaded2['root.layer_2'].v
  # -> MirroredVariable:{
  #      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  #    }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.

  Raises:
    ValueError: if `tags` does not match the tags of the SavedModel's single
      MetaGraph, or if `filters` are given for a TensorFlow 1.x / Estimator
      SavedModel.
    FileNotFoundError: if the loader raises `errors.NotFoundError`; the
      original error is chained as the cause.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # tensor_content fields contain raw bytes in little-endian format, which
    # must be byte-swapped before use on big-endian systems.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = Loader(object_graph_proto, saved_model_proto, export_dir,
                        ckpt_options, options, filters)
      except errors.NotFoundError as err:
        # Chain the original error so its traceback is kept as the cause.
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.") from err
      root = loader.get(0)
      root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                       " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example, serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This
  behavior is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither is specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently, variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be
  supported in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example
  that exporting a model that runs on a GPU and serving it on a CPU will
  generally work, with some exceptions. `tf.device` annotations inside the body
  of the function will be hard-coded in the exported model; this type of
  annotation is discouraged. Device-specific operations, e.g. with "cuDNN" in
  the name or with device-specific layouts, may cause issues. Currently a
  `DistributionStrategy` is another exception: active distribution strategies
  will cause device placements to be hard-coded in a function. Exporting a
  single-device computation and importing under a `DistributionStrategy` is
  not currently supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the
  consumer of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  A single tf.function can generate many ConcreteFunctions. If a downstream tool
  wants to refer to all concrete functions generated by a single tf.function you
  can use the `function_aliases` argument to store a map from the alias name to
  all concrete function names.
  E.g.

  ```python
  class MyModel(tf.Module):

    @tf.function
    def func(self):
      ...

    @tf.function
    def serve(self):
      ...
      self.func()

  model = MyModel()
  signatures = {
      'serving_default': model.serve.get_concrete_function(),
  }
  options = tf.saved_model.SaveOptions(function_aliases={
      'my_func': model.func,
  })
  tf.saved_model.save(model, export_dir, signatures, options)
  ```

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving. When omitted, a default `SaveOptions()` is used.

  Raises:
    ValueError: If `obj` is not trackable.
    FileNotFoundError: If, when executing eagerly, waiting for the pending save
      operations raises `NotFoundError`. This can happen when saving from a
      device other than the computational device; setting
      `experimental_io_device` on `tf.saved_model.SaveOptions` (e.g. to
      '/job:localhost') may resolve it.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  options = options or save_options.SaveOptions()
  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()
  _, exported_graph, object_saver, asset_info = _build_meta_graph(
      obj, export_dir, signatures, options, meta_graph_def)
  saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

  # Write the checkpoint, copy assets into the assets directory, and write out
  # the SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  # The checkpoint is written through the (optional) I/O device configured on
  # the SaveOptions, mirroring the restore path in the loader.
  ckpt_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=options.experimental_io_device)
  object_saver.save(
      utils_impl.get_variables_path(export_dir), options=ckpt_options)
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  # Note that this needs to be the last file operation when saving the
  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
  # indication that the SavedModel is completely written.
  if context.executing_eagerly():
    try:
      context.async_wait()  # Ensure save operations have completed.
    except errors.NotFoundError as err:
      # Chain explicitly so the original NotFoundError is reported as the
      # direct cause rather than as an error raised during handling.
      raise FileNotFoundError(
          str(err) + "\n If trying to save on a different device from the "
          "computational device, consider setting the "
          "`experimental_io_device` option on tf.saved_model.SaveOptions "
          "to the io_device such as '/job:localhost'.") from err

  path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  # Atomic write so a partially-written saved_model.pb is never observed.
  file_io.atomic_write_string_to_file(
      path, saved_model.SerializeToString(deterministic=True))

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)