def __init__(self, obj, name, call_with_mapped_captures=None):
  """Builds `SaveSpec`s from the tensors that `obj` serializes.

  Args:
    obj: A `Trackable` object that implements `_serialize_to_tensors`.
    name: The name of this saveable; used as the prefix of every spec name.
    call_with_mapped_captures: Optional callable used to invoke a
      `ConcreteFunction` serialize function with remapped captures.
  """
  self._trackable = obj
  self._call_with_mapped_captures = call_with_mapped_captures

  serialize_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
  if (call_with_mapped_captures and
      isinstance(serialize_fn, core.ConcreteFunction)):
    serialized = call_with_mapped_captures(serialize_fn, [])
  else:
    serialized = serialize_fn()

  self._local_names = []
  self._prefix = saveable_compat.get_saveable_name(self._trackable) or ""
  save_specs = []
  for local_name, tensor_or_slices in serialized.items():
    self._local_names.append(local_name)
    spec_name = name + trackable_utils.escape_local_name(local_name)
    # Normalize a bare tensor to the {slice_spec: tensor} dict shape so one
    # loop below handles both cases; "" means "no slice spec".
    if not isinstance(tensor_or_slices, dict):
      tensor_or_slices = {"": tensor_or_slices}
    # Create separate specs for each slice spec.
    for slice_spec, tensor in tensor_or_slices.items():
      save_specs.append(
          saveable_object.SaveSpec(tensor, slice_spec, spec_name))
  super(TrackableSaveable, self).__init__(obj, save_specs, name)
def _create_serialize_to_tensor_saveable(self, saveable_factories): """Creates a saveable using the _serialize_to_tensor method.""" # Extract the saveable name from the checkpoint key. This will be used as # the cache key or the name to pass to the saveable factory. suffix = saveable_compat.get_saveable_name(self.trackable) or "" saveable_name = _extract_saveable_name( self.object_proto.attributes[0].checkpoint_key) + suffix # Try to find the cached saveable (only in graph mode). if not context.executing_eagerly(): existing_op = self._checkpoint.restore_ops_by_name.get( saveable_name, None) if existing_op is not None: return existing_op, {} saveables_cache = self._checkpoint.saveables_cache.setdefault( self.trackable, {}) if saveable_name in saveables_cache: return [], {saveable_name: saveables_cache[saveable_name]} saveable = saveable_factories[ trackable_utils.SERIALIZE_TO_TENSORS_NAME](name=saveable_name) if not context.executing_eagerly(): saveables_cache[saveable_name] = saveable return [], {saveable_name: saveable}
def extract_saveable_name(trackable, checkpoint_key):
  """Returns the saveable name to use for `trackable`'s checkpoint key.

  Args:
    trackable: A `Trackable` object.
    checkpoint_key: Full checkpoint key string, expected to contain the
      object-attribute separator (`trackable_utils.OBJECT_ATTRIBUTES_NAME`).

  Returns:
    The checkpoint key itself when a legacy saveable name is registered for
    `trackable`; otherwise the key truncated just past the attribute
    separator and its trailing "/".

  Raises:
    ValueError: If `checkpoint_key` does not contain the attribute separator
      (propagated from `str.index`).
  """
  if saveable_compat.get_saveable_name(trackable) is not None:
    # If there is a legacy saveable name, the saveable name is the checkpoint
    # key.
    return checkpoint_key
  # Substring the checkpoint key to the end of ".ATTRIBUTES/". Derive the
  # length from the constant (plus 1 for the trailing "/") instead of
  # hard-coding 12, so this cannot drift if the constant ever changes.
  marker_end = (checkpoint_key.index(trackable_utils.OBJECT_ATTRIBUTES_NAME)
                + len(trackable_utils.OBJECT_ATTRIBUTES_NAME) + 1)
  return checkpoint_key[:marker_end]
def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Gets a map of saveable factories and corresponding checkpoint keys.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from `Trackable._map_resources()` which
      copies the object into another graph. Generally only resource objects
      (e.g. Variables, Tables) will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
  unmapped_registered_savers = collections.defaultdict(dict)
  for trackable, object_name in object_names.items():
    # object_to_save is only used to retrieve the saving functionality. For
    # keys and other data, use the original `trackable`.
    object_to_save = util.get_mapped_trackable(trackable, object_map)

    saver_name = registration.get_registered_saver_name(object_to_save)
    if saver_name:
      # Add the original trackable instead of `object_to_save` to the returned
      # dict because the original is needed for writing the object proto.
      unmapped_registered_savers[saver_name][object_name] = trackable
    else:
      checkpoint_factory_map[trackable] = []
      for name, saveable_factory in (
          saveable_object_util.saveable_objects_from_trackable(
              object_to_save).items()):  # pylint: disable=protected-access
        # Retrieve the legacy saveable name (for compatibility purposes during
        # SaveableObject deprecation).
        key_suffix = saveable_compat.get_saveable_name(
            object_to_save) or name
        checkpoint_key = trackable_utils.checkpoint_key(
            object_name, key_suffix)

        if not saveable_compat.force_checkpoint_conversion_enabled():
          # Make sure to set the name as the legacy saveable name if there
          # is one (only when checkpoint conversion is disabled).
          name = key_suffix

        checkpoint_factory_map[trackable].append(
            _CheckpointFactoryData(
                factory=saveable_factory,
                name=name,
                checkpoint_key=checkpoint_key))
  return checkpoint_factory_map, unmapped_registered_savers
def trace_save_and_restore(obj):
  """Traces `Trackable` serialize- and restore-from-tensors functions.

  Args:
    obj: A `Trackable` object.

  Returns:
    A tuple `(concrete_save, concrete_restore)` of concrete `Function`s.
  """
  legacy_name = saveable_compat.get_saveable_name(obj)
  obj_save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
  obj_restore_fn = obj._restore_from_tensors  # pylint: disable=protected-access

  if isinstance(obj_save_fn, defun.ConcreteFunction):
    # Already traced; use it directly.
    concrete_save = obj_save_fn
  else:

    @def_function.function
    def save_fn():
      tensor_dict = obj_save_fn()
      if legacy_name:
        # If there is a legacy decorator, append the name to the keys.
        return {
            f"{legacy_name}{key}": value
            for key, value in tensor_dict.items()
        }
      return tensor_dict

    concrete_save = save_fn.get_concrete_function()

  if isinstance(obj_restore_fn, defun.ConcreteFunction):
    concrete_restore = obj_restore_fn
  else:

    @def_function.function
    def restore_fn(restored_tensors):
      if legacy_name:
        # Do the opposite operation of save_fn(): strip the legacy name
        # prefix from the keys before handing tensors to the object.
        restored_tensors = {
            key[len(legacy_name):]: value
            for key, value in restored_tensors.items()
        }
      obj_restore_fn(restored_tensors)

    # Trace restore with the same structure that save produces.
    concrete_restore = restore_fn.get_concrete_function(
        concrete_save.structured_outputs)

  return concrete_save, concrete_restore
def test_multiple_specs_single_saveable(self):
  """One saveable with two specs round-trips through the converter."""

  class TwoVariableTrackable(base.Trackable):

    def __init__(self):
      self.a = variables.Variable(35.0)
      self.b = variables.Variable(40.0)

    def _gather_saveables_for_checkpoint(self):
      return {"foo": lambda name: _MultiSpecSaveable(self, name)}

  trackable = TwoVariableTrackable()
  compat_converter = saveable_object_util.SaveableCompatibilityConverter(
      trackable)

  # Serialization should flatten the two specs into two named tensors.
  tensors = compat_converter._serialize_to_tensors()
  self.assertLen(tensors, 2)
  self.assertEqual(35, self.evaluate(tensors["foo-a"]))
  self.assertEqual(40, self.evaluate(tensors["foo-b"]))

  # Restoring from new values must write through to both variables.
  compat_converter._restore_from_tensors({"foo-a": 5., "foo-b": 6.})
  self.assertEqual(5, self.evaluate(trackable.a))
  self.assertEqual(6, self.evaluate(trackable.b))

  # Make sure that the legacy saveable name has been applied.
  self.assertEqual("foo",
                   saveable_compat.get_saveable_name(compat_converter))
def _add_attributes_to_object_graph_for_saveable_objects(
    checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures, saveables_cache):
  """Create SaveableObjects and corresponding SerializedTensor protos.

  Args:
    checkpoint_factory_map: Dict mapping trackable -> list of factory data
      (as produced by `get_checkpoint_factories_and_keys`).
    object_graph_proto: Object graph proto whose node attributes are updated
      in place.
    node_ids: Dict mapping trackable -> node index in `object_graph_proto`.
    object_map: Optional dict mapping `Trackable` to copied `Trackable`.
    call_with_mapped_captures: Optional callable for invoking concrete
      functions with remapped captures; forwarded to the saveable factory.
    saveables_cache: Optional cache of previously-built SaveableObjects, or
      None to disable caching.

  Returns:
    A tuple of (list of SaveableObjects, feed-dict additions or None).
  """
  named_saveable_objects = []
  if saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    object_proto = object_graph_proto.nodes[node_ids[trackable]]
    object_to_save = _get_mapped_trackable(trackable, object_map)
    if saveables_cache is not None:
      cached_attributes = saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      name = factory_data.name
      key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      saveables = cached_attributes.get(
          name) if cached_attributes else None
      if saveables is not None:
        for saveable in saveables:
          if key not in saveable.name:
            # The checkpoint key for this SaveableObject is different. We
            # need to re-create it.
            saveables = None
            del cached_attributes[name]
            break

      if saveables is None:
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, key, call_with_mapped_captures)
        else:
          # Factory slot already holds a concrete saveable/op, not a callable.
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable, )
        else:
          # Wrap a raw op/variable into one or more SaveableObjects.
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=maybe_saveable, name=key))
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      for saveable in saveables:
        if isinstance(saveable, python_state.PythonStateSaveable):
          if feed_additions is None:
            assert saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            saveable = saveable.freeze()
          else:
            saveable_feed_dict = saveable.feed_dict_additions()
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError(
                    f"The object {trackable} tried to feed a value for the "
                    f"Tensor {new_feed_key} when saving, but another object "
                    "is already feeding a value.")
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)

      # Update the object proto.
      # For updated Trackables that override serialize_to_tensors, add an
      # attribute for each tensor that is serialized.
      # For Trackables that have SaveableObjects or a legacy saveable name,
      # add a single attribute to the proto.
      if (isinstance(saveables[0], saveable_object_util.TrackableSaveable)
          and saveable_compat.get_saveable_name(object_to_save) is None):
        for local_name, local_key in (
            saveables[0].get_proto_names_and_checkpoint_keys()):
          object_proto.attributes.add(
              name=local_name,
              checkpoint_key=local_key,
              full_name=_get_full_name(object_to_save))
      else:
        object_proto.attributes.add(
            name=name,
            checkpoint_key=key,
            full_name=_get_full_name(object_to_save))

  return named_saveable_objects, feed_additions