def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Serializes the object graph of `saveable_view` into `export_dir`.

  The SavedObjectGraph proto is populated by
  `saveable_view.fill_object_graph_proto`. Each node is then written with
  `_write_object_proto`, and the result is stored as "object_graph.pb" in the
  extra-assets subdirectory of `export_dir`, which is created if needed.

  Args:
    saveable_view: Object exposing `nodes` (an ordered sequence of objects)
      and `fill_object_graph_proto(proto)`.
    export_dir: Destination directory of the export.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
  """
  # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(graph_proto)

  # Map each node to its index. Resource variables are also keyed by their
  # handle tensor, and assets by their asset-path handle, so that either the
  # object or its handle resolves to the same node id.
  ids_by_object = util.ObjectIdentityDictionary()
  for node_index, node in enumerate(saveable_view.nodes):
    ids_by_object[node] = node_index
    if resource_variable_ops.is_resource_variable(node):
      handle_key = node.handle
    elif isinstance(node, tracking.TrackableAsset):
      handle_key = node.asset_path.handle
    else:
      continue
    ids_by_object[handle_key] = node_index

  for node, node_proto in zip(saveable_view.nodes, graph_proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index, ids_by_object)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  graph_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  serialized_graph = graph_proto.SerializeToString()
  file_io.write_string_to_file(graph_path, serialized_graph)
def _write_object_graph(root, export_dir, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`.

  Every object reachable from `root` (per `util.find_objects`) is serialized
  into a SavedObjectGraph proto. Polymorphic functions are added to it, and
  the result is written as "object_graph.pb" inside the extra-assets
  subdirectory of `export_dir`, which is created if needed.

  Args:
    root: The root object whose reachable objects are serialized.
    export_dir: Destination directory of the export.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
  """
  # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  checkpointable_objects, node_ids, slot_variables = util.find_objects(root)
  util.fill_object_graph_proto(checkpointable_objects, node_ids,
                               slot_variables, proto)

  # A separate id map (distinct from `node_ids` above) that additionally keys
  # resource-variable handles and asset-path handles to their owner's index,
  # so either the object or its handle resolves to the same node id.
  node_ids_with_handles = util.ObjectIdentityDictionary()
  for i, obj in enumerate(checkpointable_objects):
    node_ids_with_handles[obj] = i
    if resource_variable_ops.is_resource_variable(obj):
      node_ids_with_handles[obj.handle] = i
    elif isinstance(obj, tracking.TrackableAsset):
      node_ids_with_handles[obj.asset_path.handle] = i

  for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index)
  function_serialization.add_polymorphic_functions_to_object_graph_proto(
      checkpointable_objects, proto, node_ids_with_handles)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename,
                               proto.SerializeToString())
def __init__(self, root):
  """Builds the node list for `root`, extended with the functions it exposes.

  Starts from the objects, ids, and slot variables returned by
  `util.find_objects(root)`. Then, for every one of those original objects,
  the functions reported by `self._list_functions` are recorded in
  `self.functions`, and any function not yet known is appended to
  `self.nodes` with the next free id in `self.node_ids`.

  Args:
    root: The root object passed to `util.find_objects`.
  """
  found_objects, found_ids, found_slots = util.find_objects(root)
  self.nodes = found_objects
  self.node_ids = found_ids
  self.slot_variables = found_slots
  self.functions = util.ObjectIdentityDictionary()

  # Only the objects present before this loop are scanned; function nodes
  # appended below are deliberately excluded from the scan.
  for original_node in list(self.nodes):
    listed = self._list_functions(original_node)
    self.functions[original_node] = listed
    for fn in listed.values():
      if fn in self.node_ids:
        continue
      self.node_ids[fn] = len(self.nodes)
      self.nodes.append(fn)
      # Do not recurse into functions looking for other functions assigned to
      # their attributes; that is sometimes true for concrete functions but
      # not helpful.
      self.functions[fn] = {}
      if not isinstance(fn, def_function.Function):
        continue
      # Listing the concrete functions is done for its side effects:
      # - populate the cache for functions that have an input_signature and
      #   have not been called;
      # - force side effects of creating concrete functions, e.g. create
      #   variables on first run.
      fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas the
  objects in `self.nodes` (and the tensors captured by
  `self.concrete_functions`) will be from an eager context. Resource mapping
  adds resource handle ops to the main GraphDef of a SavedModel, which allows
  the C++ loader API to interact with variables.

  Side effects: records node ids in `self.captured_tensor_node_ids`. For each
  eager tensor captured by a concrete function that is not already recorded
  there (and whose dtype is not in `_UNCOPIABLE_DTYPES`), it appends a
  `_CapturedConstant` node to `self.nodes` and registers it in
  `self.node_ids`.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping each resource variable in `self.nodes`
        to the uninitialized graph copy created for it.
      resource_map: A dictionary mapping original resource tensors (resource
        handles, variable handles, asset handles, and copied eager captures)
        to their newly created graph counterparts.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from `self.nodes`.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = util.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})

  # Phase 1: recreate resources, variables, and assets of existing nodes in the
  # current graph, remembering which node owns each original handle tensor.
  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.TrackableResource):
      new_resource = obj.create_resource()
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.TrackableAsset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path.handle] = node_id

  # Phase 2: eager tensors captured by concrete functions that do not belong
  # to any node are copied into the graph as constants and get a new node.
  # Recording each capture in captured_tensor_node_ids ensures a tensor
  # captured by several functions is only copied once.
  # NOTE(review): _UNCOPIABLE_DTYPES presumably lists dtypes that cannot be
  # materialized via `.numpy()` / `constant_op.constant` — confirm.
  for concrete_function in self.concrete_functions:
    for capture in concrete_function.captured_inputs:
      if (isinstance(capture, ops.EagerTensor)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        copied_tensor = constant_op.constant(capture.numpy())
        node_id = len(self.nodes)
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  return object_map, resource_map, asset_info
def __init__(self, root):
  """Builds the node list for `root`, extended with its polymorphic functions.

  Starts from the objects, ids, and slot variables returned by
  `util.find_objects(root)`. Then, for each of those original objects, the
  polymorphic functions reported by `self._list_polymorphic_functions` are
  recorded in `self.polymorphic_functions`. Any function not yet known is
  appended to `self.nodes` with the next free id in `self.node_ids`.

  Args:
    root: The root object passed to `util.find_objects`.
  """
  checkpointable_objects, node_ids, slot_variables = util.find_objects(root)
  self.nodes = checkpointable_objects
  self.node_ids = node_ids
  self.slot_variables = slot_variables
  self.polymorphic_functions = util.ObjectIdentityDictionary()

  # Also add polymorphic functions as nodes. Iterate over a snapshot: the loop
  # body appends to `self.nodes`, and iterating the live list would also scan
  # the freshly appended function nodes for further functions.
  nodes_without_functions = list(self.nodes)
  for obj in nodes_without_functions:
    self.polymorphic_functions[obj] = self._list_polymorphic_functions(obj)
    for function in self.polymorphic_functions[obj].values():
      if function not in self.node_ids:
        self.node_ids[function] = len(self.nodes)
        self.nodes.append(function)
        # Function nodes are not scanned for functions assigned to their
        # attributes; record an empty mapping so every node still has an
        # entry in `self.polymorphic_functions`.
        self.polymorphic_functions[function] = {}
        # Force listing the concrete functions for the side effects:
        # - populate the cache for polymorphic functions that have an
        #   input_signature and have not been called.
        # - force side effects of creation of concrete functions, e.g. create
        #   variables on first run.
        function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
def _map_resources(accessible_objects):
  """Creates graph-side replacements for the resources held by some objects.

  New resource handle ops are built in the current default graph, whereas
  `accessible_objects` will be from an eager context. Adding them to the main
  GraphDef of a SavedModel lets the C++ loader API interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain
      resources, to create replacements for.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: Maps each resource variable in `accessible_objects` to its
        uninitialized graph copy.
      resource_map: Maps original resource tensors (resource handles,
        variable handles, and those added by `_process_asset`) to the newly
        created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from `accessible_objects`.
  """
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = util.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})

  # The checks are ordered: an object is handled by the first category it
  # matches and skipped by the rest.
  for candidate in accessible_objects:
    if isinstance(candidate, tracking.TrackableResource):
      resource_map[candidate.resource_handle] = candidate.create_resource()
      continue
    if resource_variable_ops.is_resource_variable(candidate):
      graph_variable = resource_variable_ops.copy_to_graph_uninitialized(
          candidate)
      object_map[candidate] = graph_variable
      resource_map[candidate.handle] = graph_variable.handle
      continue
    if isinstance(candidate, tracking.TrackableAsset):
      _process_asset(candidate, asset_info, resource_map)

  return object_map, resource_map, asset_info