def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Serialize the object graph of `saveable_view` into the export directory.

  The SavedObjectGraph proto plays a role similar to the
  CheckpointableObjectGraph proto stored in checkpoints, and is intended to
  eventually become part of the SavedModel itself.

  Args:
    saveable_view: Object providing `fill_object_graph_proto(proto)` and an
      ordered `nodes` sequence that is paired, position by position, with the
      proto's `nodes`.
    export_dir: Directory of the SavedModel being written.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(graph_proto)

  # Map every node to its position in `saveable_view.nodes`. A resource
  # variable's `handle` and a TrackableAsset's `asset_path.handle` are also
  # registered, under the id of the object that owns them.
  ids_by_object = util.ObjectIdentityDictionary()
  for node_id, node in enumerate(saveable_view.nodes):
    ids_by_object[node] = node_id
    if resource_variable_ops.is_resource_variable(node):
      owned_handle = node.handle
    elif isinstance(node, tracking.TrackableAsset):
      owned_handle = node.asset_path.handle
    else:
      continue
    ids_by_object[owned_handle] = node_id

  for node, node_proto in zip(saveable_view.nodes, graph_proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index, ids_by_object)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  graph_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(graph_path, graph_proto.SerializeToString())
def _write_object_graph(root, export_dir, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`.

  SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  checkpoint. It will eventually go into the SavedModel.

  Args:
    root: Root object whose reachable objects are found with
      `util.find_objects`.
    export_dir: Directory of the SavedModel being written; the proto is
      written to `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
  """
  proto = saved_object_graph_pb2.SavedObjectGraph()
  checkpointable_objects, checkpoint_node_ids, slot_variables = (
      util.find_objects(root))
  util.fill_object_graph_proto(checkpointable_objects, checkpoint_node_ids,
                               slot_variables, proto)

  # A separate map from the one returned by `find_objects`: besides each
  # object, it also maps a resource variable's `handle` and a TrackableAsset's
  # `asset_path.handle` to the owning object's node id. This extended map is
  # what `add_polymorphic_functions_to_object_graph_proto` receives below.
  node_ids = util.ObjectIdentityDictionary()
  for i, obj in enumerate(checkpointable_objects):
    node_ids[obj] = i
    if resource_variable_ops.is_resource_variable(obj):
      node_ids[obj.handle] = i
    elif isinstance(obj, tracking.TrackableAsset):
      node_ids[obj.asset_path.handle] = i

  for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index)
  function_serialization.add_polymorphic_functions_to_object_graph_proto(
      checkpointable_objects, proto, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename,
                               proto.SerializeToString())
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Serialize the object graph of `saveable_view`, including its functions.

  The SavedObjectGraph proto plays a role similar to the
  CheckpointableObjectGraph proto stored in checkpoints, and is intended to
  eventually become part of the SavedModel itself.

  Steps:
    1. `saveable_view.fill_object_graph_proto` populates the proto.
    2. Each entry of `saveable_view.concrete_functions` is serialized. Entries
       whose serializer returns None are skipped; the rest are stored in
       `proto.concrete_functions` under the function's `name`.
    3. Per-object data is written by `_write_object_proto`.
    4. The proto is written to
       `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

  Args:
    saveable_view: Object exposing `fill_object_graph_proto`, `nodes`,
      `concrete_functions` and `captured_tensor_node_ids`.
    export_dir: Directory of the SavedModel being written.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(graph_proto)

  struct_coder = nested_structure_coder.StructureCoder()
  for fn in saveable_view.concrete_functions:
    fn_proto = function_serialization.serialize_concrete_function(
        fn, saveable_view.captured_tensor_node_ids, struct_coder)
    if fn_proto is None:
      continue
    graph_proto.concrete_functions[fn.name].CopyFrom(fn_proto)

  for node, node_proto in zip(saveable_view.nodes, graph_proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  graph_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(graph_path, graph_proto.SerializeToString())
def _write_object_graph(obj, export_dir):
  """Write a SavedObjectGraph proto describing `obj` into `export_dir`.

  The SavedObjectGraph proto plays a role similar to the
  CheckpointableObjectGraph proto stored in checkpoints, and is intended to
  eventually become part of the SavedModel itself.

  The graph is built by `util.make_object_graph_without_attributes` and
  written to `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`. The
  directory is created first if needed.
  """
  graph_proto = util.make_object_graph_without_attributes(
      obj, proto=saved_object_graph_pb2.SavedObjectGraph())

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  file_io.write_string_to_file(
      os.path.join(assets_dir, compat.as_bytes("object_graph.pb")),
      graph_proto.SerializeToString())
def _write_object_graph(root, export_dir, asset_file_def_index):
  """Write a SavedObjectGraph proto for everything reachable from `root`.

  The SavedObjectGraph proto plays a role similar to the
  CheckpointableObjectGraph proto stored in checkpoints, and is intended to
  eventually become part of the SavedModel itself.

  Args:
    root: Object whose reachable objects are collected by `util.find_objects`.
    export_dir: Directory of the SavedModel; output goes to
      `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  found_objects, ids_by_object, slots = util.find_objects(root)
  util.fill_object_graph_proto(found_objects, ids_by_object, slots,
                               graph_proto)

  for node, node_proto in zip(found_objects, graph_proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  graph_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(graph_path, graph_proto.SerializeToString())
def load(export_dir):
  """Load a SavedModel from `export_dir`.

  Only SavedModels that contain an object graph file at
  `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb` are supported. The
  root object is recreated from that SavedObjectGraph proto.

  Args:
    export_dir: Path of the SavedModel directory (converted with
      `compat.as_bytes`).

  Returns:
    The root object produced by `_recreate_object_graph`.

  Raises:
    NotImplementedError: If no object graph file exists, i.e. the SavedModel
      was not exported with `tf.saved_model.save`.
  """
  object_graph_filename = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
      compat.as_bytes("object_graph.pb"))
  if not file_io.file_exists(object_graph_filename):
    raise NotImplementedError(
        "Currently only SavedModels exported with `tf.saved_model.save` may be "
        "imported. Other SavedModels may eventually be supported via load().")
  # The object graph is used to create the root object. The `with` block
  # closes the file handle deterministically instead of leaking it.
  with file_io.FileIO(object_graph_filename, "rb") as object_graph_file:
    object_graph_string = object_graph_file.read()
  object_graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  root = _recreate_object_graph(object_graph_proto)
  # TODO(allenl): load functions from the SavedModel into the eager context
  return root