def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Fills the SavedObject message `proto` with the serialized form of `obj`.

  The type of `obj` is tested in this order, and only the first match is
  written into `proto`: asset, resource variable, `def_function.Function`,
  `defun.ConcreteFunction`. Anything else is stored as a user object, using
  the registration from `revived_types.serialize` when one exists and a
  generic placeholder otherwise.

  Args:
    obj: The object to serialize.
    proto: The SavedObject message, modified in place.
    asset_file_def_index: Mapping indexed by asset objects; only read when
      `obj` is a `tracking.TrackableAsset`.
    node_ids: Node-id mapping forwarded to the function serializers.
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
    return

  if resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
    return

  if isinstance(obj, def_function.Function):
    serialized_function = function_serialization.serialize_function(
        obj, node_ids)
    proto.function.CopyFrom(serialized_function)
    return

  if isinstance(obj, defun.ConcreteFunction):
    serialized_concrete = function_serialization.serialize_concrete_function(
        obj, node_ids)
    proto.concrete_function.CopyFrom(serialized_concrete)
    return

  user_object = revived_types.serialize(obj)
  if user_object is None:
    # No revived-type registration matched; fall back to a generic record.
    user_object = saved_object_graph_pb2.SavedUserObject(
        identifier="_generic_user_object",
        version=versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]))
  proto.user_object.CopyFrom(user_object)
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Builds a SavedObjectGraph proto from `saveable_view` and writes it.

  The proto is written to
  `<export_dir>/<constants.EXTRA_ASSETS_DIRECTORY>/object_graph.pb`; the
  directory is created if needed.

  Args:
    saveable_view: View of the object graph being saved. Must provide
      `fill_object_graph_proto`, `concrete_functions`, `nodes` and
      `captured_tensor_node_ids`.
    export_dir: Directory of the SavedModel being written.
    asset_file_def_index: Mapping indexed by asset objects, forwarded to
      `_write_object_proto`.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  # One node-id map is shared by the concrete-function serializer and the
  # per-object serializer.
  # NOTE(review): assumes captured_tensor_node_ids is also the map that
  # serialize_function expects inside _write_object_proto -- confirm.
  node_ids = saveable_view.captured_tensor_node_ids

  coder = nested_structure_coder.StructureCoder()
  for concrete_function in saveable_view.concrete_functions:
    serialized = function_serialization.serialize_concrete_function(
        concrete_function, node_ids, coder)
    if serialized is not None:
      proto.concrete_functions[concrete_function.name].CopyFrom(serialized)

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    # `_write_object_proto` takes `node_ids` as a required fourth argument;
    # calling it with three raises TypeError.
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Serializes `obj` into the SavedObject message `proto` (in place).

  Dispatches on the type of `obj`:
    * `tracking.TrackableAsset` -> `proto.asset` (index taken from
      `asset_file_def_index[obj]`);
    * resource variable -> `proto.variable` (trainable, dtype, shape);
    * `def_function.Function` -> `proto.function`;
    * `defun.ConcreteFunction` -> `proto.concrete_function`;
    * anything else -> `proto.user_object`, from `revived_types.serialize`
      or, if that returns None, a generic placeholder.

  Args:
    obj: The object to serialize.
    proto: The SavedObject message to populate.
    asset_file_def_index: Mapping indexed by asset objects.
    node_ids: Node-id mapping handed to the function serializers.
  """
  is_asset = isinstance(obj, tracking.TrackableAsset)
  if is_asset:
    proto.asset.SetInParent()
    asset_index = asset_file_def_index[obj]
    proto.asset.asset_file_def_index = asset_index
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    variable_shape = obj.shape.as_proto()
    proto.variable.shape.CopyFrom(variable_shape)
  elif isinstance(obj, def_function.Function):
    function_proto = function_serialization.serialize_function(obj, node_ids)
    proto.function.CopyFrom(function_proto)
  elif isinstance(obj, defun.ConcreteFunction):
    concrete_proto = function_serialization.serialize_concrete_function(
        obj, node_ids)
    proto.concrete_function.CopyFrom(concrete_proto)
  else:
    user_proto = revived_types.serialize(obj)
    if user_proto is None:
      # Generic record for types without a matching registration.
      placeholder_version = versions_pb2.VersionDef(
          producer=1, min_consumer=1, bad_consumers=[])
      user_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object", version=placeholder_version)
    proto.user_object.CopyFrom(user_proto)
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Builds a SavedObjectGraph proto from `saveable_view` and writes it.

  The proto is written to
  `<export_dir>/<constants.EXTRA_ASSETS_DIRECTORY>/object_graph.pb`; the
  directory is created if needed.

  Args:
    saveable_view: View of the object graph being saved. Must provide
      `fill_object_graph_proto`, `concrete_functions` and `nodes`.
    export_dir: Directory of the SavedModel being written.
    asset_file_def_index: Mapping indexed by asset objects, forwarded to
      `_write_object_proto`.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  # Map every node to its index in `saveable_view.nodes`. For resource
  # variables and assets, the underlying handle maps to the same index
  # (presumably so captured handles resolve to their owning node).
  node_ids = util.ObjectIdentityDictionary()
  for i, obj in enumerate(saveable_view.nodes):
    node_ids[obj] = i
    if resource_variable_ops.is_resource_variable(obj):
      node_ids[obj.handle] = i
    elif isinstance(obj, tracking.TrackableAsset):
      node_ids[obj.asset_path.handle] = i

  coder = nested_structure_coder.StructureCoder()
  for concrete_function in saveable_view.concrete_functions:
    serialized = function_serialization.serialize_concrete_function(
        concrete_function, node_ids, coder)
    if serialized is not None:
      proto.concrete_functions[concrete_function.name].CopyFrom(serialized)

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    # `_write_object_proto` takes `node_ids` as a required fourth argument;
    # the map built above was previously never passed to it (TypeError).
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Builds a SavedObjectGraph proto from `saveable_view` and writes it.

  The proto is written to
  `<export_dir>/<constants.EXTRA_ASSETS_DIRECTORY>/object_graph.pb`; the
  directory is created if needed.

  Args:
    saveable_view: View of the object graph being saved. Must provide
      `fill_object_graph_proto`, `concrete_functions`, `nodes` and
      `captured_tensor_node_ids`.
    export_dir: Directory of the SavedModel being written.
    asset_file_def_index: Mapping indexed by asset objects, forwarded to
      `_write_object_proto`.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  # One node-id map is shared by the concrete-function serializer and the
  # per-object serializer.
  # NOTE(review): assumes captured_tensor_node_ids is also the map that
  # serialize_function expects inside _write_object_proto -- confirm.
  node_ids = saveable_view.captured_tensor_node_ids

  coder = nested_structure_coder.StructureCoder()
  for concrete_function in saveable_view.concrete_functions:
    serialized = function_serialization.serialize_concrete_function(
        concrete_function, node_ids, coder)
    if serialized is not None:
      proto.concrete_functions[concrete_function.name].CopyFrom(serialized)

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    # `_write_object_proto` takes `node_ids` as a required fourth argument;
    # calling it with three raises TypeError.
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Builds and returns a SavedObjectGraph proto for `saveable_view`.

  The object graph from `saveable_view` is copied into a new proto. Its
  concrete functions are serialized into `proto.concrete_functions` (entries
  whose serialization returns None are skipped). Each node is then filled
  via `_write_object_proto`.

  Args:
    saveable_view: View of the object graph being saved. Must provide
      `fill_object_graph_proto`, `concrete_functions`, `nodes` and
      `captured_tensor_node_ids`.
    asset_file_def_index: Mapping indexed by asset objects, forwarded to
      `_write_object_proto`.

  Returns:
    The populated `saved_object_graph_pb2.SavedObjectGraph`.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  # One node-id map is shared by the concrete-function serializer and the
  # per-object serializer.
  # NOTE(review): assumes captured_tensor_node_ids is also the map that
  # serialize_function expects inside _write_object_proto -- confirm.
  node_ids = saveable_view.captured_tensor_node_ids

  coder = nested_structure_coder.StructureCoder()
  for concrete_function in saveable_view.concrete_functions:
    serialized = function_serialization.serialize_concrete_function(
        concrete_function, node_ids, coder)
    if serialized is not None:
      proto.concrete_functions[concrete_function.name].CopyFrom(serialized)

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    # `_write_object_proto` takes `node_ids` as a required fourth argument;
    # calling it with three raises TypeError.
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)
  return proto