def _serialize_gathered_objects(self, checkpointable_objects, path_to_root,
                                object_map=None):
    """Build saveables and the object graph proto for gathered objects.

    Args:
      checkpointable_objects: List of objects to serialize. It is handed to
        `_serialize_slot_variables`, which appends any slot variables it
        finds to this same list.
      path_to_root: Mapping from object to its path; each path is converted
        to a name prefix with `_object_prefix_from_path`.
      object_map: Optional value forwarded unchanged to
        `_add_attributes_to_object_graph`.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions).
    """
    # Name prefix for every object that has a recorded path.
    object_names = object_identity.ObjectIdentityDictionary()
    for gathered, root_path in path_to_root.items():
        object_names[gathered] = _object_prefix_from_path(root_path)

    # Node id == position in `checkpointable_objects`.
    node_ids = object_identity.ObjectIdentityDictionary()
    for position, gathered in enumerate(checkpointable_objects):
        node_ids[gathered] = position

    slot_variables = _serialize_slot_variables(
        checkpointable_objects=checkpointable_objects,
        node_ids=node_ids,
        object_names=object_names)

    object_graph_proto = self._fill_object_graph_proto(
        checkpointable_objects=checkpointable_objects,
        node_ids=node_ids,
        slot_variables=slot_variables)

    saveables_and_feeds = self._add_attributes_to_object_graph(
        checkpointable_objects=checkpointable_objects,
        object_graph_proto=object_graph_proto,
        node_ids=node_ids,
        object_names=object_names,
        object_map=object_map)
    named_saveable_objects, feed_additions = saveables_and_feeds

    return named_saveable_objects, object_graph_proto, feed_additions
def __init__(self, checkpoint_view):
    """Collect nodes from `checkpoint_view` and register functions as nodes.

    Args:
      checkpoint_view: Object providing `objects_ids_and_slot_variables()`
        and `list_functions(node)`.
    """
    self.checkpoint_view = checkpoint_view
    gathered_nodes, gathered_ids, gathered_slots = (
        self.checkpoint_view.objects_ids_and_slot_variables())
    self.nodes = gathered_nodes
    self.node_ids = gathered_ids
    self.captured_tensor_node_ids = (
        object_identity.ObjectIdentityDictionary())
    self.slot_variables = gathered_slots
    self.concrete_functions = []

    # Also add `Function`s as nodes. Iterate over a snapshot because
    # `self.nodes` grows while functions are being registered.
    seen_function_names = set()
    for node in list(self.nodes):
        for function in checkpoint_view.list_functions(node).values():
            if function not in self.node_ids:
                self.node_ids[function] = len(self.nodes)
                self.nodes.append(function)

            if not isinstance(function, def_function.Function):
                candidates = [function]
            else:
                # Force listing the concrete functions for the side effects:
                #  - populate the cache for functions that have an
                #    input_signature and have not been called.
                #  - force side effects of creation of concrete functions,
                #    e.g. create variables on first run.
                candidates = (
                    function.
                    _list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access

            # Keep only the first concrete function seen under each name.
            for concrete_function in candidates:
                if concrete_function.name in seen_function_names:
                    continue
                seen_function_names.add(concrete_function.name)
                self.concrete_functions.append(concrete_function)
def map_resources(self):
    """Create graph-mode counterparts for resources held by `self.nodes`.

    New resource handle ops are created in the current default graph, while
    `accessible_objects` come from an eager context. Adding these handle ops
    to the main GraphDef of a SavedModel is what allows the C++ loader API
    to interact with variables.

    Eager tensors captured by `self.concrete_functions` that are not yet
    tracked (and whose dtype is not in `_UNCOPIABLE_DTYPES`) are copied into
    graph constants and appended to `self.nodes` as `_CapturedConstant`
    nodes.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects`
          to replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted
          from `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets
          referenced from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables
    # which may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    tensor_node_ids = self.captured_tensor_node_ids

    for node_id, obj in enumerate(self.nodes):
        if isinstance(obj, tracking.TrackableResource):
            # The right-hand side is evaluated first, so the resource is
            # created before `resource_handle` is read.
            resource_map[obj.resource_handle] = obj.create_resource()
            tensor_node_ids[obj.resource_handle] = node_id
        elif resource_variable_ops.is_resource_variable(obj):
            graph_variable = (
                resource_variable_ops.copy_to_graph_uninitialized(obj))
            object_map[obj] = graph_variable
            resource_map[obj.handle] = graph_variable.handle
            tensor_node_ids[obj.handle] = node_id
        elif isinstance(obj, tracking.TrackableAsset):
            _process_asset(obj, asset_info, resource_map)
            tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
        for capture in concrete_function.captured_inputs:
            if not isinstance(capture, ops.EagerTensor):
                continue
            if capture.dtype in _UNCOPIABLE_DTYPES:
                continue
            if capture in tensor_node_ids:
                continue
            # Untracked eager capture: copy its value into the graph and
            # register the pair as a new node.
            copied_tensor = constant_op.constant(capture.numpy())
            node_id = len(self.nodes)
            self.nodes.append(
                _CapturedConstant(
                    eager_tensor=capture, graph_tensor=copied_tensor))
            self.node_ids[capture] = node_id
            self.node_ids[self.nodes[node_id]] = node_id
            tensor_node_ids[capture] = node_id
            resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
    """Find slot variables, name them, and record them as new nodes.

    Args:
      checkpointable_objects: List of objects to scan. Each discovered slot
        variable is appended to this list in place.
      node_ids: Mapping from object to node id; extended in place with an
        entry for every discovered slot variable.
      object_names: Mapping from object to checkpoint name; extended in
        place with the generated name of every discovered slot variable.

    Returns:
      An `ObjectIdentityDictionary` mapping each optimizer-like object to a
      list of `SlotVariableReference` protos.

    Raises:
      NotImplementedError: If a slot variable has checkpoint dependencies
        of its own, or is already present in `node_ids`.
    """
    # Local alias for the nested proto message class.
    slot_reference_cls = (
        checkpointable_object_graph_pb2.
        CheckpointableObjectGraph.CheckpointableObject.
        SlotVariableReference)

    # Snapshot: the input list grows as slot variables are appended below.
    non_slot_objects = list(checkpointable_objects)
    slot_variables = object_identity.ObjectIdentityDictionary()

    for checkpointable in non_slot_objects:
        is_optimizer_like = (
            isinstance(checkpointable, optimizer_v1.Optimizer)
            # TODO(b/110718070): Fix Keras imports.
            or hasattr(checkpointable, "_create_or_restore_slot_variable"))
        if not is_optimizer_like:
            continue

        naming_scheme = _slot_variable_naming_for_optimizer(
            optimizer_path=object_names[checkpointable])
        for slot_name in checkpointable.get_slot_names():
            for original_node_id, original_variable in enumerate(
                    non_slot_objects):
                try:
                    slot_variable = checkpointable.get_slot(
                        original_variable, slot_name)
                except (AttributeError, KeyError):
                    continue
                if slot_variable is None:
                    continue

                slot_variable._maybe_initialize_checkpointable()  # pylint: disable=protected-access
                if slot_variable._checkpoint_dependencies:  # pylint: disable=protected-access
                    # TODO(allenl): Gather dependencies of slot variables.
                    raise NotImplementedError(
                        "Currently only variables with no dependencies can be saved as "
                        "slot variables. File a feature request if this limitation "
                        "bothers you.")
                if slot_variable in node_ids:
                    raise NotImplementedError(
                        "A slot variable was re-used as a dependency of a "
                        "Checkpointable object. This is not currently allowed. File a "
                        "feature request if this limitation bothers you.")

                object_names[slot_variable] = naming_scheme(
                    variable_path=object_names[original_variable],
                    slot_name=slot_name)

                # The slot variable becomes the next node in the list.
                slot_node_id = len(checkpointable_objects)
                node_ids[slot_variable] = slot_node_id
                checkpointable_objects.append(slot_variable)

                reference = slot_reference_cls(
                    slot_name=slot_name,
                    original_variable_node_id=original_node_id,
                    slot_variable_node_id=slot_node_id)
                slot_variables.setdefault(checkpointable, []).append(
                    reference)

    return slot_variables
def objects_ids_and_slot_variables(self):
    """List every accessible object together with ids and slot variables.

    Walks the object graph looking for `Checkpointable` objects which are
    dependencies of `root_checkpointable`. Slot variables are included only
    if both the variable they are slotting for and the optimizer are
    dependencies of `root_checkpointable` (i.e. if they would be saved with
    a checkpoint).

    Returns:
      A tuple of (checkpointable objects, object -> node id, slot variables)
    """
    ordered_objects, paths = self._breadth_first_traversal()

    # Name prefix for every object reached by the traversal.
    names_by_object = object_identity.ObjectIdentityDictionary()
    for reached, root_path in paths.items():
        names_by_object[reached] = _object_prefix_from_path(root_path)

    # Node id == position in the traversal order.
    ids_by_object = object_identity.ObjectIdentityDictionary()
    for position, reached in enumerate(ordered_objects):
        ids_by_object[reached] = position

    # Appends slot variables to `ordered_objects` and `ids_by_object`.
    slot_variables = _serialize_slot_variables(
        checkpointable_objects=ordered_objects,
        node_ids=ids_by_object,
        object_names=names_by_object)

    return ordered_objects, ids_by_object, slot_variables
def _breadth_first_traversal(self):
    """Walk dependencies of `self.root` breadth-first, recording paths.

    Returns:
      A tuple of (objects in visit order, object -> path), where a path is
      a tuple of `base.CheckpointableReference` leading from `self.root`
      (the root itself maps to the empty tuple).

    Raises:
      NotImplementedError: If a visited object is a
        `tracking.NotCheckpointable`.
    """
    visit_order = []
    path_to_root = object_identity.ObjectIdentityDictionary()
    path_to_root[self.root] = ()
    pending = collections.deque([self.root])

    while pending:
        current = pending.popleft()
        if isinstance(current, tracking.NotCheckpointable):
            raise NotImplementedError(
                ("The object %s does not support object-based saving. File a "
                 "feature request if this limitation bothers you. In the meantime, "
                 "you can remove the dependency on this object and save everything "
                 "else.")
                % (current,))
        visit_order.append(current)

        # Every queued object already has a recorded path.
        current_path = path_to_root[current]
        for name, dependency in self.list_dependencies(current):
            if dependency in path_to_root:
                # Already reached earlier; keep the first path found.
                continue
            reference = base.CheckpointableReference(name, dependency)
            path_to_root[dependency] = current_path + (reference,)
            pending.append(dependency)

    return visit_order, path_to_root
def __init__(self, root):
    """Initialise the view over `root` with empty identity-keyed tables.

    Args:
      root: Forwarded unchanged to the parent class constructor.
    """
    super(_AugmentedGraphView, self).__init__(root)
    # Identity-keyed table for functions; starts empty.
    self._functions = object_identity.ObjectIdentityDictionary()
    # Object -> (name -> dep)
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()