def save(self, file_prefix, options=None):
  """Writes every registered tensor slice to a checkpoint at `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor naming the checkpoint
      prefix to write.
    options: Optional `CheckpointOptions`; a default instance is used when
      omitted.

  Returns:
    The result of `io_ops.save_v2`: an `Operation`, or None when executing
    eagerly.
  """
  options = options or checkpoint_options.CheckpointOptions()

  # (name, slice_spec, value) triples, collected in dictionary order.
  entries = []
  for checkpoint_key, slice_to_tensor in self._tensor_slice_dict.items():
    for slice_spec, tensor in slice_to_tensor.items():
      if not isinstance(tensor, saveable_object.SaveSpec):
        entries.append((checkpoint_key, slice_spec, tensor))
        continue
      value = tensor.tensor
      # A `None` value means this SaveableObject is recorded in the object
      # graph but contributes no data to the checkpoint.
      if value is None:
        continue
      entries.append((tensor.name, tensor.slice_spec, value))

  tensor_names = [name for name, _, _ in entries]
  slice_specs = [spec for _, spec, _ in entries]
  tensors = [value for _, _, value in entries]

  with ops.device(options.experimental_io_device or "cpu:0"):
    return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    A dictionary mapping from SaveableObject names to the values returned by
    each saveable's `restore` (restore operations). Empty when there are no
    saveable objects.
  """
  options = options or checkpoint_options.CheckpointOptions()
  restore_specs = []
  # One list of spec names per saveable; used below to regroup the flat list
  # of restored tensors back into per-saveable lists.
  tensor_structure = []
  for saveable in self._saveable_objects:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      saveable_tensor_structure.append(spec.name)
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))

  if not restore_specs:
    # Nothing to read from the checkpoint. `zip(*[])` cannot be unpacked into
    # three names, so short-circuit instead of raising: each saveable (if
    # any) receives an empty tensor list.
    return {saveable.name: saveable.restore([], restored_shapes=None)
            for saveable in self._saveable_objects}

  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
  restore_device = options.experimental_io_device or "cpu:0"
  with ops.device(restore_device):
    restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
                                         tensor_slices, tensor_dtypes)
  structured_restored_tensors = nest.pack_sequence_as(
      tensor_structure, restored_tensors)

  restore_ops = {}
  for saveable, saveable_tensors in zip(self._saveable_objects,
                                        structured_restored_tensors):
    restore_ops[saveable.name] = saveable.restore(
        saveable_tensors, restored_shapes=None)
  return restore_ops
def setUp(self):
  """Splits the first physical CPU into three logical CPUs for the tests."""
  super(SaverTest, self).setUp()
  first_cpu = config.list_physical_devices("CPU")[0]
  virtual_cpus = [context.LogicalDeviceConfiguration() for _ in range(3)]
  config.set_logical_device_configuration(first_cpu, virtual_cpus)
  self.local_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=LOCALHOST)
def restore(self, file_prefix, options=None):
  """Restores every single-device and registered saver from `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    When not run eagerly or when there is at most one single-device saver,
    the merged dictionaries returned by the single-device savers (mapping
    SaveableObject names to restore operations); otherwise an empty dict.
    All after-restore callbacks run before this returns.
  """
  options = options or checkpoint_options.CheckpointOptions()

  def _restore_everything():
    ops_by_name = {}
    # Visit devices in sorted order so op creation does not propagate
    # non-deterministic dictionary ordering.
    for device_name, device_saver in sorted(
        self._single_device_savers.items()):
      with ops.device(device_name):
        ops_by_name.update(device_saver.restore(file_prefix, options))
    for _, registered_restore in self._registered_savers.values():
      registered_restore(file_prefix)
    return ops_by_name

  # Wrapping in a tf.function re-traces on every restore, so only do it when
  # required: eager execution with several single-device savers. The retrace
  # makes sure the latest option values (e.g. experimental_io_device) apply.
  needs_function = (context.executing_eagerly()
                    and len(self._single_device_savers) > 1)
  if needs_function:

    @def_function.function(jit_compile=False)
    def tf_function_restore():
      _restore_everything()
      return {}

    restore_ops = tf_function_restore()
  else:
    restore_ops = _restore_everything()

  for callback in self._after_restore_callbacks:
    callback()
  return restore_ops
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
  """
  options = options or checkpoint_options.CheckpointOptions()
  tensor_names = []
  tensor_dtypes = []
  slice_specs = []
  # (checkpoint_key, slice_spec) for each requested tensor, in request order,
  # so results can be filed back without re-walking the dictionary.
  requested_keys = []

  for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
    for slice_spec, tensor in tensor_slices.items():
      tensor_dtypes.append(tensor.dtype)
      if isinstance(tensor, saveable_object.SaveSpec):
        slice_specs.append(tensor.slice_spec)
        tensor_names.append(tensor.name)
      else:
        slice_specs.append(slice_spec)
        tensor_names.append(checkpoint_key)
      requested_keys.append((checkpoint_key, slice_spec))

  restore_device = options.experimental_io_device or "cpu:0"
  with ops.device(restore_device):
    restored_tensors = io_ops.restore_v2(
        file_prefix, tensor_names, slice_specs, tensor_dtypes)

  # Results are matched to requests by position. Index lazily instead of
  # `pop(0)` (quadratic) or `zip` (which would iterate `restored_tensors`
  # even when nothing was requested).
  # NOTE(review): with zero requests, graph-mode `restore_v2` may return an
  # Operation rather than a list, so it must not be iterated in that case.
  restored_tensor_dict = {}
  for index, (checkpoint_key, slice_spec) in enumerate(requested_keys):
    restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
        restored_tensors[index])
  return restored_tensor_dict
def restore(self, file_prefix, options=None):
  """Restores all registered objects from the checkpoint at `file_prefix`.

  Each single-device saver reads its tensors. As soon as every
  (checkpoint key, slice spec) input of a per-object restore function has
  been read, that function is called with the gathered inputs. Registered
  savers run after the default restore path.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    An empty dict when executing eagerly with more than one single-device
    saver or with `experimental_io_device` set (the restore then runs inside
    a tf.function). Otherwise, the merged dict results of the per-object
    restore functions (SaveableObject names to restore operations).
  """
  options = options or checkpoint_options.CheckpointOptions()

  def _restore_all():
    # Inputs gathered so far: restore function -> checkpoint key ->
    # slice spec -> tensor.
    gathered = {}
    # How many inputs each restore function is still waiting for.
    outstanding = {
        fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}
    restore_ops = {}

    # Sort by device name so op creation does not depend on dictionary
    # ordering.
    for device, saver in sorted(self._single_device_savers.items()):
      with ops.device(device):
        restored = saver.restore(file_prefix, options)
        for checkpoint_key, slice_and_tensor in restored.items():
          for slice_spec, tensor in slice_and_tensor.items():
            target_fn = self._keys_to_restore_fn[(checkpoint_key, slice_spec)]
            inputs = gathered.setdefault(target_fn, {})
            inputs.setdefault(checkpoint_key, {})[slice_spec] = tensor
            outstanding[target_fn] -= 1
            if outstanding[target_fn] != 0:
              continue
            # Every input of this function is now loaded: run it.
            result = target_fn(inputs)
            if isinstance(result, dict):
              restore_ops.update(result)

    # Registered restore methods run after the default restore ops.
    for _, registered_restore in self._registered_savers.values():
      registered_restore(file_prefix)
    return restore_ops

  restore_device = options.experimental_io_device or "cpu:0"

  # A tf.function re-traces on every restore, so only use one when needed:
  # eager execution with several single-device savers or an explicit io
  # device. The retrace picks up the latest option values.
  use_function = context.executing_eagerly() and (
      len(self._single_device_savers) > 1 or options.experimental_io_device)
  if not use_function:
    return _restore_all()

  @def_function.function(jit_compile=False)
  def tf_function_restore():
    _restore_all()
    return {}

  with ops.device(restore_device):
    return tf_function_restore()
def save(self, file_prefix, options=None):
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Registered savers write first, then each single-device saver writes one
  shard under a temporary prefix, and finally the shards are merged into
  `file_prefix` with `merge_v2_checkpoints`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    options: Optional `CheckpointOptions` object.

  Returns:
    An `Operation` (the merge op), or None when executing eagerly with more
    than one single-device saver (the save then runs inside a tf.function
    whose result is discarded).

  Raises:
    ValueError: If a registered saver returns something other than None or
      a (possibly nested/empty) collection of string tensors.
  """
  options = options or checkpoint_options.CheckpointOptions()

  # IMPLEMENTATION DETAILS: most clients should skip.
  #
  # Suffix for any well-formed "checkpoint_prefix", when sharded.
  # Transformations:
  # * Users pass in "save_path" in save() and restore().  Say "myckpt".
  # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
  #
  # Example:
  #   During runtime, a temporary directory is first created, which contains
  #   files
  #
  #     <train dir>/myckpt_temp/
  #        part-?????-of-?????{.index, .data-00000-of-00001}
  #
  #   Before .save() finishes, they will be (hopefully, atomically) renamed to
  #
  #     <train dir>/
  #        myckpt{.index, .data-?????-of-?????}
  #
  #   Filesystems with eventual consistency (such as S3) don't need a
  #   temporary location. Using a temporary directory in those cases might
  #   cause situations where files are not available during copy.
  #
  # Users only need to interact with the user-specified prefix, which is
  # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
  # prefix directly, instead of any physical pathname.  (On failure and
  # subsequent restore, an outdated and orphaned temporary directory can be
  # safely removed.)
  with ops.device("CPU"):
    # ".part" for s3:// prefixes, "_temp/part" for everything else.
    sharded_suffix = array_ops.where(
        string_ops.regex_full_match(file_prefix, "^s3://.*"),
        constant_op.constant(".part"),
        constant_op.constant("_temp/part"))
    tmp_checkpoint_prefix = string_ops.string_join(
        [file_prefix, sharded_suffix])
    # These path ops are built here, outside `save_fn`, i.e. outside any
    # tf.function that may wrap it below.
    registered_paths = {
        saver_name: registered_saver_filename(file_prefix, saver_name)
        for saver_name in self._registered_savers
    }

  def save_fn():
    # Every prefix written by this save; all of them are merged at the end.
    saved_prefixes = []
    # Save with the registered savers. These run before default savers due to
    # the API contract.
    # NOTE: the loop variable `save_fn` shadows this function's own name
    # inside its body; it refers to each registered saver's save function.
    for saver_name, (save_fn, _) in self._registered_savers.items():
      maybe_saved_prefixes = save_fn(registered_paths[saver_name])
      if maybe_saved_prefixes is not None:
        flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
        if not all(
            tensor_util.is_tf_type(x) and x.dtype == dtypes.string
            for x in flattened_saved_prefixes):
          raise ValueError(
              "Registered saver must return a (maybe empty) list of "
              f"string type tensors. Got {maybe_saved_prefixes}.")
        saved_prefixes.extend(flattened_saved_prefixes)

    # (Default saver) Save with single device savers.
    num_shards = len(self._single_device_savers)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    # NOTE(review): stays None when there are no single-device savers; the
    # merge device then relies on experimental_io_device being set.
    last_device = None
    # Devices are visited in sorted order, so shard numbering is stable.
    for shard, (device, saver) in enumerate(
        sorted(self._single_device_savers.items())):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                        num_shards_tensor)
      saved_prefixes.append(shard_prefix)
      with ops.device(device):
        # _SingleDeviceSaver will use the CPU device when necessary, but
        # initial read operations should be placed on the SaveableObject's
        # device.
        sharded_saves.append(saver.save(shard_prefix, options))

    # The merge must only run after every shard has been written.
    with ops.control_dependencies(sharded_saves):
      # Merge on the io_device if specified, otherwise co-locates the merge op
      # with the last device used.
      merge_device = (
          options.experimental_io_device or
          saveable_object_util.set_cpu0(last_device))
      with ops.device(merge_device):
        # V2 format write path consists of a metadata merge step.  Once
        # merged, attempts to delete the temporary directory,
        # "<user-fed prefix>_temp".
        return gen_io_ops.merge_v2_checkpoints(
            saved_prefixes, file_prefix, delete_old_dirs=True)

  # Since this causes a function re-trace on each save, limit it to the
  # cases where it is needed: eager and when there are multiple tasks/single
  # device savers. Note that the retrace is needed to ensure we pick up the
  # latest values of options like experimental_io_device.
  if context.executing_eagerly() and len(self._single_device_savers) > 1:
    # Run the whole save inside a tf.function; its result is not returned,
    # so this path returns None.
    @def_function.function(jit_compile=False)
    def tf_function_save():
      save_fn()
    tf_function_save()
  else:
    return save_fn()
def restore(self, save_path, options=None):
  """Restore a training checkpoint with host mesh placement.

  Args:
    save_path: Path prefix of the checkpoint to read, or None to only run
      initializers.
    options: Optional `CheckpointOptions` object.

  Returns:
    `util.InitializationOnlyStatus` when `save_path` is None;
    `util.NameBasedSaverStatus` when the checkpoint has no object graph proto
    (name-based compatibility mode); otherwise `util.CheckpointLoadStatus`.
  """
  options = options or checkpoint_options.CheckpointOptions()
  if save_path is None:
    return util.InitializationOnlyStatus(self._graph_view, ops.uid())
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  graph_building = not context.executing_eagerly()
  if graph_building:
    dtype_map = None
  else:
    dtype_map = reader.get_variable_to_dtype_map()
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    # The object graph proto does not exist in this checkpoint. Try the
    # name-based compatibility mode.
    restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
        save_path=save_path,
        dtype_map=dtype_map)
    if not graph_building:
      for existing_trackable in self._graph_view.list_objects():
        # pylint: disable=protected-access
        existing_trackable._maybe_initialize_trackable()
        existing_trackable._name_based_restores.add(restore_coordinator)
        existing_trackable._name_based_attribute_restore(restore_coordinator)
        # pylint: enable=protected-access
    return util.NameBasedSaverStatus(
        restore_coordinator, graph_view=self._graph_view)

  if graph_building:
    if self._file_prefix_placeholder is None:
      # DTensor change: provide a hint for mesh broadcasting to put the input
      # onto the host mesh (one replicated copy per local device).
      self._file_prefix_placeholder = api.pack(
          [constant_op.constant("model")] * self._mesh.num_local_devices(),
          layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_tensor = self._file_prefix_placeholder
    file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
  else:
    # DTensor change: provide a hint for mesh broadcasting to put the input
    # onto the host mesh (one replicated copy per local device).
    file_prefix_tensor = api.pack(
        [constant_op.constant(save_path)] * self._mesh.num_local_devices(),
        layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_feed_dict = None

  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  # DTensor Change: Hook the proper DSaver in restore.
  checkpoint = _DCheckpointRestoreCoordinator(
      mesh=self._mesh,
      object_graph_proto=object_graph_proto,
      save_path=save_path,
      save_path_tensor=file_prefix_tensor,
      reader=reader,
      restore_op_cache=self._restore_op_cache,
      graph_view=self._graph_view,
      options=options,
      saveables_cache=self._saveables_cache)
  base.CheckpointPosition(
      checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

  # Attached dependencies are not attached to the root, so should be restored
  # separately.
  if self._graph_view.attached_dependencies:
    for ref in self._graph_view.attached_dependencies:
      if ref.name == "root":
        # Root dependency is automatically added to attached dependencies --
        # this can be ignored since it maps back to the root object.
        continue
      # Find proto ID of attached dependency among the root's children (if
      # it is in the proto at all).
      proto_id = next(
          (proto_ref.node_id
           for proto_ref in object_graph_proto.nodes[0].children
           if proto_ref.local_name == ref.name),
          None)
      if proto_id is None:
        # The attached dependency is not in this checkpoint's object graph,
        # so there is nothing to restore for it. Previously a None proto_id
        # was passed on to CheckpointPosition.
        continue
      if proto_id in checkpoint.object_by_proto_id:
        # Object has already been restored. This can happen when there's an
        # indirect connection from the attached object to the root.
        continue
      base.CheckpointPosition(
          checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

  load_status = util.CheckpointLoadStatus(
      checkpoint,
      graph_view=self._graph_view,
      feed_dict=file_prefix_feed_dict)
  return load_status
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, SavedModel stores the **object graph** of the saved
  object. The graph contains nodes (`tf.Module`, `tf.Variable`,
  `tf.function`, Keras layers, etc.) and edges that are the name of the
  attributes connecting the objects.

  *Example 1*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      ['root.child_layer', 'root.child_layer.v'])
  loaded['root.child_layer'].v.numpy()  # 5.
  loaded['root.child_layer'].v is loaded['root.child_layer.v']  # True
  ```

  *Example 2*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  # Create a variable
  new_variable = tf.Variable(0.)
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      {'root.child_layer': None, 'root.child_layer.v': new_variable})
  loaded['root.child_layer'].v.numpy()  # 5.
  new_variable.numpy()  # 5.
  ```

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  ```
  model = tf.Module()
  model.layer_1 = tf.Module()
  model.layer_1.v = tf.Variable(5.)
  model.layer_2 = tf.Module()
  model.layer_2.v = tf.Variable(7.)
  tf.saved_model.save(model, '/tmp/model')
  # Load with no strategy
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      ['root.layer_1'])
  loaded['root.layer_1'].v
  # <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    loaded2 = tf.__internal__.saved_model.load_partial(
        '/tmp/model',
        ['root.layer_2'])
  loaded2['root.layer_2'].v
  # MirroredVariable:{
  #   0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  # }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the
      child attribute names to reach that node in the form:
      `root.{attribute_name}`. The loader will load all of the specified
      nodes and their recursive descendants. When this option is defined,
      the loader will return a dictionary mapping the node paths to the
      loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load.
      Optional if the SavedModel contains a single MetaGraph, as for those
      exported from `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects. When
    `filters` is empty, returns `{"root": <loaded root object>}`.

  Raises:
    ValueError: If `tags` does not match the tags of the single MetaGraph,
      or if `filters` is given for a TF1.x / Estimator SavedModel.
    FileNotFoundError: If the checkpoint files cannot be found on the
      loading device.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems and so
    # requires a byteswap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = Loader(object_graph_proto, saved_model_proto, export_dir,
                        ckpt_options, options, filters)
      except errors.NotFoundError as err:
        # Chain explicitly so the original NotFoundError is reported as the
        # direct cause.
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.") from err
      root = loader.get(0)
      root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                       " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  # `loader` only exists on the V2 path; filters are rejected above for V1.
  if filters:
    return {node_path: loader.get(node_path) for node_path in filters}
  return {"root": root}