def _add_attributes_to_object_graph_for_saveable_objects(
    self, checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures):
  """Create SaveableObjects and corresponding SerializedTensor protos.

  For every trackable in `checkpoint_factory_map`, adds one attribute per
  factory-data record to the trackable's node in `object_graph_proto`. It then
  builds, or reuses from `self._saveables_cache`, the SaveableObjects backing
  that attribute.

  Args:
    checkpoint_factory_map: Mapping from trackable to a list of factory-data
      records. Each record exposes `name`, `checkpoint_key` and `factory`.
    object_graph_proto: Proto whose `nodes[node_ids[trackable]]` entry
      receives the new attributes.
    node_ids: Mapping from trackable to its index in
      `object_graph_proto.nodes`.
    object_map: Passed to `_get_mapped_trackable` to choose the object used
      as the cache key. It is only consulted when caching is enabled.
    call_with_mapped_captures: Forwarded to
      `saveable_object_util.create_saveable_object` for callable factories.

  Returns:
    A `(named_saveable_objects, feed_additions)` tuple:
      - `named_saveable_objects`: the flat list of SaveableObjects to save.
      - `feed_additions`: `None` when `self._saveables_cache` is `None`.
        Otherwise, a dict merging the `feed_dict_additions()` of every
        Python-state saveable.

  Raises:
    AssertionError: if a newly created SaveableObject's name does not contain
      its checkpoint key, or if two saveables try to feed the same tensor.
  """
  named_saveable_objects = []
  if self._saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    object_proto = object_graph_proto.nodes[node_ids[trackable]]
    if self._saveables_cache is not None:
      # The cache is keyed by the mapped object, not the original trackable.
      object_to_save = _get_mapped_trackable(trackable, object_map)
      cached_attributes = self._saveables_cache.setdefault(
          object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      # One proto attribute per factory record. The chained assignments also
      # bind the local `name` / `key` used below.
      attribute = object_proto.attributes.add()
      attribute.name = name = factory_data.name
      attribute.checkpoint_key = key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      # An empty cache dict is treated the same as "no cache hit".
      saveables = cached_attributes.get(
          name) if cached_attributes else None
      if saveables is not None:
        for saveable in saveables:
          if key not in saveable.name:
            # The checkpoint key for this SaveableObject is different. We
            # need to re-create it.
            saveables = None
            del cached_attributes[name]
            break

      if saveables is None:
        # Cache miss (or invalidated entry): build the SaveableObjects.
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, key, call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable, )
        else:
          # Figure out the name-based Saver's name for this variable. If it's
          # already a SaveableObject we'd just get the checkpoint key back, so
          # we leave full_name blank.
          saver_dict = saveable_object_util.op_list_to_dict(
              [maybe_saveable], convert_variable_to_tensor=False)
          # Exactly one entry is expected; the tuple unpack raises otherwise.
          full_name, = saver_dict.keys()
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=maybe_saveable, name=key))
          for saveable in saveables:
            saveable.full_name = full_name
        # Every produced SaveableObject must be named under this key.
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      for saveable in saveables:
        # Only saveables that carry a `full_name` (set above for the
        # op/variable path) update the proto's name-based Saver name.
        if hasattr(saveable, "full_name"):
          attribute.full_name = saveable.full_name
        if isinstance(saveable, base.PythonStateSaveable):
          if feed_additions is None:
            assert self._saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            # The frozen object is what gets appended below.
            saveable = saveable.freeze()
          else:
            # Merge this saveable's feeds, refusing duplicate feed keys.
            saveable_feed_dict = saveable.feed_dict_additions()
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError(
                    f"The object {trackable} tried to feed a value for the "
                    f"Tensor {new_feed_key} when saving, but another object "
                    "is already feeding a value.")
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)

  return named_saveable_objects, feed_additions
def _add_attributes_to_object_graph(self, trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures): """Create SaveableObjects and corresponding SerializedTensor protos.""" named_saveable_objects = [] if self._saveables_cache is None: # No SaveableObject caching. Either we're executing eagerly, or building a # static save which is specialized to the current Python state. feed_additions = None else: # If we are caching SaveableObjects, we need to build up a feed_dict with # functions computing volatile Python state to be saved with the # checkpoint. feed_additions = {} for checkpoint_id, (trackable, object_proto) in enumerate( zip(trackable_objects, object_graph_proto.nodes)): assert node_ids[trackable] == checkpoint_id object_name = object_names[trackable] if object_map is None: object_to_save = trackable else: object_to_save = object_map.get(trackable, trackable) if self._saveables_cache is not None: cached_attributes = self._saveables_cache.setdefault( object_to_save, {}) else: cached_attributes = None for name, saveable_factory in ( object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) if cached_attributes is None: saveables = None else: saveables = cached_attributes.get(name, None) if saveables is not None: for saveable in saveables: if attribute.checkpoint_key not in saveable.name: # The checkpoint key for this SaveableObject is different. We # need to re-create it. 
saveables = None del cached_attributes[name] break if saveables is None: if callable(saveable_factory): maybe_saveable = saveable_object_util.create_saveable_object( saveable_factory, attribute.checkpoint_key, call_with_mapped_captures) else: maybe_saveable = saveable_factory if isinstance(maybe_saveable, saveable_object_lib.SaveableObject): saveables = (maybe_saveable, ) else: # Figure out the name-based Saver's name for this variable. If it's # already a SaveableObject we'd just get the checkpoint key back, so # we leave full_name blank. saver_dict = saveable_object_util.op_list_to_dict( [maybe_saveable], convert_variable_to_tensor=False) full_name, = saver_dict.keys() saveables = tuple( saveable_object_util.saveable_objects_for_op( op=maybe_saveable, name=attribute.checkpoint_key)) for saveable in saveables: saveable.full_name = full_name for saveable in saveables: if attribute.checkpoint_key not in saveable.name: raise AssertionError(( "The object %s produced a SaveableObject with name '%s' for " "attribute '%s'. Expected a name containing '%s'." ) % (trackable, name, saveable.name, attribute.checkpoint_key)) if cached_attributes is not None: cached_attributes[name] = saveables optional_restore = None for saveable in saveables: if optional_restore is None: optional_restore = saveable.optional_restore else: optional_restore = optional_restore and saveable.optional_restore if hasattr(saveable, "full_name"): attribute.full_name = saveable.full_name if isinstance(saveable, base.PythonStateSaveable): if feed_additions is None: assert self._saveables_cache is None # If we're not caching saveables, then we're either executing # eagerly or building a static save/restore (e.g. for a # SavedModel). In either case, we should embed the current Python # state in the graph rather than relying on a feed dict. 
saveable = saveable.freeze() else: saveable_feed_dict = saveable.feed_dict_additions() for new_feed_key in saveable_feed_dict.keys(): if new_feed_key in feed_additions: raise AssertionError(( "The object %s tried to feed a value for the Tensor %s " "when saving, but another object is already feeding a " "value.") % (trackable, new_feed_key)) feed_additions.update(saveable_feed_dict) named_saveable_objects.append(saveable) if optional_restore is None: optional_restore = False attribute.optional_restore = optional_restore return named_saveable_objects, feed_additions
def _add_attributes_to_object_graph_for_saveable_objects( checkpoint_factory_map, object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache): """Create SaveableObjects and corresponding SerializedTensor protos.""" named_saveable_objects = [] if saveables_cache is None: # No SaveableObject caching. Either we're executing eagerly, or building a # static save which is specialized to the current Python state. feed_additions = None else: # If we are caching SaveableObjects, we need to build up a feed_dict with # functions computing volatile Python state to be saved with the # checkpoint. feed_additions = {} for trackable, factory_data_list in checkpoint_factory_map.items(): object_proto = object_graph_proto.nodes[node_ids[trackable]] object_to_save = _get_mapped_trackable(trackable, object_map) if saveables_cache is not None: cached_attributes = saveables_cache.setdefault(object_to_save, {}) else: cached_attributes = None for factory_data in factory_data_list: name = factory_data.name key = factory_data.checkpoint_key saveable_factory = factory_data.factory # See if we can skip saving this checkpoint key. saveables = cached_attributes.get( name) if cached_attributes else None if saveables is not None: for saveable in saveables: if key not in saveable.name: # The checkpoint key for this SaveableObject is different. We # need to re-create it. 
saveables = None del cached_attributes[name] break if saveables is None: if callable(saveable_factory): maybe_saveable = saveable_object_util.create_saveable_object( saveable_factory, key, call_with_mapped_captures) else: maybe_saveable = saveable_factory if isinstance(maybe_saveable, saveable_object_lib.SaveableObject): saveables = (maybe_saveable, ) else: saveables = tuple( saveable_object_util.saveable_objects_for_op( op=maybe_saveable, name=key)) for saveable in saveables: if key not in saveable.name: raise AssertionError( f"The object {trackable} produced a SaveableObject with name " f"'{saveable.name}' for attribute '{name}'. Expected a name" f" containing '{key}'.") if cached_attributes is not None: cached_attributes[name] = saveables for saveable in saveables: if isinstance(saveable, python_state.PythonStateSaveable): if feed_additions is None: assert saveables_cache is None # If we're not caching saveables, then we're either executing # eagerly or building a static save/restore (e.g. for a # SavedModel). In either case, we should embed the current Python # state in the graph rather than relying on a feed dict. saveable = saveable.freeze() else: saveable_feed_dict = saveable.feed_dict_additions() for new_feed_key in saveable_feed_dict.keys(): if new_feed_key in feed_additions: raise AssertionError( f"The object {trackable} tried to feed a value for the " f"Tensor {new_feed_key} when saving, but another object " "is already feeding a value.") feed_additions.update(saveable_feed_dict) named_saveable_objects.append(saveable) # Update the object proto. # For updated Trackables that override serialize_to_tensors, add an # attribute for each tensor that is serialized. # For Trackables that have SaveableObjects or a legacy saveable name, # add a single attribute to the proto. 
if (isinstance(saveables[0], saveable_object_util.TrackableSaveable) and saveable_compat.get_saveable_name(object_to_save) is None): for local_name, local_key in ( saveables[0].get_proto_names_and_checkpoint_keys()): object_proto.attributes.add( name=local_name, checkpoint_key=local_key, full_name=_get_full_name(object_to_save)) else: object_proto.attributes.add( name=name, checkpoint_key=key, full_name=_get_full_name(object_to_save)) return named_saveable_objects, feed_additions