def testAssets(self):
    """Round-trips a SavedModel that carries a single registered asset file.

    Two files are written to the test temp directory, but only one of them is
    registered in the ASSET_FILEPATHS collection. After saving under the
    "foo" tag and loading the model back, the test checks that the registered
    asset is described in the assets collection and present in the export's
    assets directory, and that the unregistered file is absent from it.
    """
    temp_root = compat.as_bytes(tf.test.get_temp_dir())
    export_dir = os.path.join(temp_root, compat.as_bytes("with-assets"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=tf.Graph()) as sess:
        v = tf.Variable(42, name="v")
        sess.run(tf.initialize_all_variables())
        self.assertEqual(42, v.eval())

        # Register one asset file through a constant filepath tensor.
        registered_path = os.path.join(temp_root,
                                       compat.as_bytes("hello42.txt"))
        file_io.write_string_to_file(registered_path, "foo bar baz")
        registered_tensor = tf.constant(registered_path,
                                        name="asset_file_tensor")
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, registered_tensor)

        # This file exists on disk but is never added to the collection.
        unregistered_path = os.path.join(temp_root,
                                         compat.as_bytes("ignored.txt"))
        file_io.write_string_to_file(unregistered_path, "will be ignored")

        builder.add_meta_graph_and_variables(
            sess,
            ["foo"],
            assets_collection=tf.get_collection(
                tf.GraphKeys.ASSET_FILEPATHS))

    # Write the SavedModel out.
    builder.save()

    with self.test_session(graph=tf.Graph()) as sess:
        loaded_graph = loader.load(sess, ["foo"], export_dir)

        # Exactly one asset entry must have been recorded.
        packed_assets = loaded_graph.collection_def[
            constants.ASSETS_KEY].any_list.value
        self.assertEqual(len(packed_assets), 1)
        asset = manifest_pb2.AssetFile()
        packed_assets[0].Unpack(asset)

        exported_assets_dir = os.path.join(
            compat.as_bytes(export_dir),
            compat.as_bytes(constants.ASSETS_DIRECTORY))

        # The registered file was exported with its original contents.
        exported_asset = os.path.join(exported_assets_dir,
                                      compat.as_bytes("hello42.txt"))
        exported_text = file_io.read_file_to_string(exported_asset)
        self.assertEqual("foo bar baz", compat.as_text(exported_text))
        self.assertEqual("hello42.txt", asset.filename)
        self.assertEqual("asset_file_tensor:0",
                         asset.tensor_binding.tensor_name)

        # The unregistered file must not appear in the export.
        self.assertFalse(
            file_io.file_exists(
                os.path.join(exported_assets_dir,
                             compat.as_bytes("ignored.txt"))))
def _add_asset_to_collection(self, asset_filename, asset_tensor):
    """Records one asset in the graph's asset collection.

    An `AssetFile` proto is filled in with the given filename and the name of
    the given tensor, packed into an `Any` proto, and appended to the
    collection stored under `constants.ASSETS_KEY`.

    Args:
      asset_filename: The filename of the asset to be added.
      asset_tensor: The asset tensor whose name populates the tensor binding
        of the asset proto.
    """
    asset_file = manifest_pb2.AssetFile()
    asset_file.filename = asset_filename
    asset_file.tensor_binding.tensor_name = asset_tensor.name

    # Collections hold the asset description wrapped in an Any proto.
    packed_asset = Any()
    packed_asset.Pack(asset_file)
    ops.add_to_collection(constants.ASSETS_KEY, packed_asset)
def doBasicsOneExportPath(self,
                          export_path,
                          clear_devices=False,
                          global_step=GLOBAL_STEP,
                          sharded=True):
    """Exports a small graph and verifies everything that was written out.

    Builds a graph with variables placed on two CPU devices, exports it with
    `exporter.Exporter`, then re-imports the meta graph and checks the stored
    graph def, init op, signatures, assets and restored variable values.

    Args:
      export_path: Base directory handed to `Exporter.export`; the versioned
        sub-directory is derived from it and `global_step`.
      clear_devices: Forwarded to `Exporter.init`. When true, device fields of
        the comparison graph def are blanked before comparing.
      global_step: Initial value of the `global_step` variable; also used to
        format the versioned export directory name.
      sharded: Forwarded to `tf.train.Saver`; also selects which variables
        filename is used when restoring.
    """
    # Build a graph with 2 parameter nodes on different devices.
    tf.reset_default_graph()
    with tf.Session(
            target="",
            config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        # v2 is an unsaved variable derived from v0 and v1. It is used to
        # exercise the ability to run an init op when restoring a graph.
        with sess.graph.device("/cpu:0"):
            v0 = tf.Variable(10, name="v0")
        with sess.graph.device("/cpu:1"):
            v1 = tf.Variable(20, name="v1")
        v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
        assign_v2 = tf.assign(v2, tf.add(v0, v1))
        init_op = tf.group(assign_v2, name="init_op")

        tf.add_to_collection("v", v0)
        tf.add_to_collection("v", v1)
        tf.add_to_collection("v", v2)

        global_step_tensor = tf.Variable(global_step, name="global_step")
        named_tensor_bindings = {
            "logical_input_A": v0,
            "logical_input_B": v1
        }
        signatures = {
            "foo": exporter.regression_signature(input_tensor=v0,
                                                 output_tensor=v1),
            "generic": exporter.generic_signature(named_tensor_bindings)
        }

        asset_filepath_orig = os.path.join(tf.test.get_temp_dir(),
                                           "hello42.txt")
        asset_file = tf.constant(asset_filepath_orig, name="filename42")
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)

        with gfile.FastGFile(asset_filepath_orig, "w") as f:
            f.write("your data here")
        assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

        # Written to disk but never registered, so it must not be exported.
        ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
        with gfile.FastGFile(ignored_asset, "w") as f:
            f.write("additional data here")

        tf.initialize_all_variables().run()

        # Run an export.
        save = tf.train.Saver({"v0": v0, "v1": v1},
                              restore_sequentially=True,
                              sharded=sharded)
        export = exporter.Exporter(save)
        export.init(
            sess.graph.as_graph_def(),
            init_op=init_op,
            clear_devices=clear_devices,
            default_graph_signature=exporter.classification_signature(
                input_tensor=v0),
            named_graph_signatures=signatures,
            assets_collection=assets_collection)
        export.export(export_path,
                      global_step_tensor,
                      sess,
                      exports_to_keep=gc.largest_export_versions(2))

    # Restore graph.
    compare_def = tf.get_default_graph().as_graph_def()
    tf.reset_default_graph()
    # Every exported artifact lives under this versioned directory.
    version_dir = os.path.join(
        export_path, constants.VERSION_FORMAT_SPECIFIER % global_step)
    with tf.Session(
            target="",
            config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        save = tf.train.import_meta_graph(
            os.path.join(version_dir, constants.META_GRAPH_DEF_FILENAME))
        self.assertIsNotNone(save)
        meta_graph_def = save.export_meta_graph()
        collection_def = meta_graph_def.collection_def

        # Validate custom graph_def.
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        self.assertEqual(len(graph_def_any), 1)
        graph_def = tf.GraphDef()
        graph_def_any[0].Unpack(graph_def)
        if clear_devices:
            for node in compare_def.node:
                node.device = ""
        self.assertProtoEquals(compare_def, graph_def)

        # Validate init_op.
        init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
        self.assertEqual(len(init_ops), 1)
        self.assertEqual(init_ops[0], "init_op")

        # Validate signatures.
        signatures_any = collection_def[
            constants.SIGNATURES_KEY].any_list.value
        self.assertEqual(len(signatures_any), 1)
        signatures = manifest_pb2.Signatures()
        signatures_any[0].Unpack(signatures)
        default_signature = signatures.default_signature
        self.assertEqual(
            default_signature.classification_signature.input.tensor_name,
            "v0:0")
        bindings = signatures.named_signatures[
            "generic"].generic_signature.map
        self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
        self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
        read_foo_signature = (
            signatures.named_signatures["foo"].regression_signature)
        self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
        self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

        # Validate the assets.
        assets_any = collection_def[constants.ASSETS_KEY].any_list.value
        self.assertEqual(len(assets_any), 1)
        asset = manifest_pb2.AssetFile()
        assets_any[0].Unpack(asset)
        assets_path = os.path.join(version_dir, constants.ASSETS_DIRECTORY,
                                   "hello42.txt")
        # Use a context manager so the file handle is closed after reading.
        with gfile.GFile(assets_path) as asset_reader:
            asset_contents = asset_reader.read()
        self.assertEqual(asset_contents, "your data here")
        self.assertEqual("hello42.txt", asset.filename)
        self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)
        ignored_asset_path = os.path.join(version_dir,
                                          constants.ASSETS_DIRECTORY,
                                          "ignored.txt")
        self.assertFalse(gfile.Exists(ignored_asset_path))

        # Validate graph restoration. The variables filename differs between
        # sharded and non-sharded exports.
        if sharded:
            variables_name = constants.VARIABLES_FILENAME_PATTERN
        else:
            variables_name = constants.VARIABLES_FILENAME
        save.restore(sess, os.path.join(version_dir, variables_name))
        self.assertEqual(10, tf.get_collection("v")[0].eval())
        self.assertEqual(20, tf.get_collection("v")[1].eval())
        # v2 is not saved; running the init op recomputes it as v0 + v1.
        tf.get_collection(constants.INIT_OP_KEY)[0].run()
        self.assertEqual(30, tf.get_collection("v")[2].eval())
def load_session_bundle_from_path(export_dir, target="", config=None):
    """Load session bundle from the given path.

    The function reads input from the export_dir, constructs the graph data to
    the default graph and restores the parameters for the session created.

    The default graph is reset before the session is created. If any step
    after session creation fails, the session is closed before the exception
    propagates.

    Args:
      export_dir: the directory that contains files exported by exporter.
      target: The execution engine to connect to. See target in tf.Session()
      config: A ConfigProto proto with configuration options. See config in
        tf.Session()

    Returns:
      session: a tensorflow session created from the variable files.
      meta_graph: a meta graph proto saved in the exporter directory.

    Raises:
      RuntimeError: if the required files are missing or contain unrecognizable
        fields, i.e. the exported model is invalid.
    """
    meta_graph_filename = os.path.join(export_dir,
                                       constants.META_GRAPH_DEF_FILENAME)
    if not file_io.file_exists(meta_graph_filename):
        raise RuntimeError("Expected meta graph file missing %s" %
                           meta_graph_filename)

    # Prefer the single variables file; fall back to the sharded pattern.
    variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
    if not file_io.file_exists(variables_filename):
        variables_filename = os.path.join(
            export_dir, constants.VARIABLES_FILENAME_PATTERN)
        if not file_io.get_matching_files(variables_filename):
            raise RuntimeError("Expected variables file missing %s" %
                               variables_filename)
    assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

    # Reads meta graph file.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_def.ParseFromString(
        file_io.read_file_to_string(meta_graph_filename))

    collection_def = meta_graph_def.collection_def
    graph_def = tf.GraphDef()
    if constants.GRAPH_KEY in collection_def:
        # Use serving graph_def in MetaGraphDef collection_def if exists
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        if len(graph_def_any) != 1:
            raise RuntimeError(
                "Expected exactly one serving GraphDef in : %s" %
                meta_graph_def)
        graph_def_any[0].Unpack(graph_def)
        # Replace the graph def in meta graph proto.
        meta_graph_def.graph_def.CopyFrom(graph_def)

    tf.reset_default_graph()
    sess = tf.Session(target, graph=None, config=config)
    try:
        # Import the graph.
        saver = tf.train.import_meta_graph(meta_graph_def)
        # Restore the session.
        saver.restore(sess, variables_filename)

        init_op_tensor = None
        if constants.INIT_OP_KEY in collection_def:
            init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
            if len(init_ops) != 1:
                raise RuntimeError(
                    "Expected exactly one serving init op in : %s" %
                    meta_graph_def)
            init_op_tensor = tf.get_collection(constants.INIT_OP_KEY)[0]

        # Create asset input tensor list: each asset tensor name is fed the
        # path of its file inside the export's assets directory.
        asset_tensor_dict = {}
        if constants.ASSETS_KEY in collection_def:
            assets_any = collection_def[constants.ASSETS_KEY].any_list.value
            for asset in assets_any:
                asset_pb = manifest_pb2.AssetFile()
                asset.Unpack(asset_pb)
                asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = (
                    os.path.join(assets_dir, asset_pb.filename))

        # Compare against None explicitly rather than relying on the
        # truthiness of a graph element.
        if init_op_tensor is not None:
            # Run the init op.
            sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
    except Exception:
        # Do not leak the session when loading fails part-way through.
        sess.close()
        raise

    return sess, meta_graph_def
def init(self,
         graph_def=None,
         init_op=None,
         clear_devices=False,
         default_graph_signature=None,
         named_graph_signatures=None,
         assets_collection=None,
         assets_callback=gfile_copy_callback):
    """Initialization.

    All arguments are validated before any exporter state or graph collection
    is modified, so a call that raises leaves the exporter untouched.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the
        graph.
      assets_collection: A collection of constant asset filepath tensors. If
        set the assets will be exported into the asset directory.
      assets_callback: callback with two argument called during export with the
        list of files to copy and the asset path.

    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None.
      ValueError: if asset file path tensors are not non-empty constant string
        scalar tensors.
    """
    # Reject repeated initialization before touching any state.
    if self._has_init:
        raise RuntimeError("init should be called only once")

    # Avoid Dangerous default value []
    if named_graph_signatures is None:
        named_graph_signatures = {}

    # Validate the assets first; collect them locally so that a failure does
    # not leave a partially filled self._assets_to_copy behind.
    pending_assets = []
    if assets_collection:
        for asset_tensor in assets_collection:
            asset_filepath = self._file_path_value(asset_tensor)
            if not asset_filepath:
                raise ValueError("invalid asset filepath tensor %s" %
                                 asset_tensor)
            basename = os.path.basename(asset_filepath)
            pending_assets.append((asset_filepath, basename, asset_tensor))

    if init_op and not isinstance(init_op, ops.Operation):
        raise TypeError("init_op needs to be an Operation: %s" % init_op)

    # Validation passed: commit state.
    self._has_init = True
    for asset_filepath, basename, _ in pending_assets:
        self._assets_to_copy[asset_filepath] = basename

    if graph_def or clear_devices:
        graph_copy = graph_pb2.GraphDef()
        if graph_def:
            graph_copy.CopyFrom(graph_def)
        else:
            graph_copy.CopyFrom(ops.get_default_graph().as_graph_def())
        if clear_devices:
            for node in graph_copy.node:
                node.device = ""
        graph_any_buf = Any()
        graph_any_buf.Pack(graph_copy)
        ops.add_to_collection(constants.GRAPH_KEY, graph_any_buf)

    if init_op:
        ops.add_to_collection(constants.INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
        signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
        signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = Any()
    signatures_any_buf.Pack(signatures_proto)
    ops.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf)

    # Record one AssetFile entry per asset, binding filename to tensor name.
    for _, filename, tensor in pending_assets:
        asset = manifest_pb2.AssetFile()
        asset.filename = filename
        asset.tensor_binding.tensor_name = tensor.name
        asset_any_buf = Any()
        asset_any_buf.Pack(asset)
        ops.add_to_collection(constants.ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
    """Load session bundle from the given path.

    The function reads input from the export_dir, constructs the graph data to
    the default graph and restores the parameters for the session created.

    The default graph is reset before the session is created. When no
    variables files are found in export_dir the restore step is skipped. If
    any step after session creation fails, the session is closed before the
    exception propagates.

    Args:
      export_dir: the directory that contains files exported by exporter.
      target: The execution engine to connect to. See target in tf.Session()
      config: A ConfigProto proto with configuration options. See config in
        tf.Session()
      meta_graph_def: optional object of type MetaGraphDef. If this object is
        present, then it is used instead of parsing MetaGraphDef from
        export_dir. Note that its graph_def field is overwritten in place when
        the proto carries a serving GraphDef collection.

    Returns:
      session: a tensorflow session created from the variable files.
      meta_graph: a meta graph proto saved in the exporter directory.

    Raises:
      RuntimeError: if the required files are missing or contain unrecognizable
        fields, i.e. the exported model is invalid.
    """
    if not meta_graph_def:
        meta_graph_filename = os.path.join(export_dir,
                                           constants.META_GRAPH_DEF_FILENAME)
        if not file_io.file_exists(meta_graph_filename):
            raise RuntimeError("Expected meta graph file missing %s" %
                               meta_graph_filename)
        # Reads meta graph file.
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        meta_graph_def.ParseFromString(
            file_io.read_file_to_string(meta_graph_filename,
                                        binary_mode=True))

    checkpoint_sharded = False
    variables_index_filename = os.path.join(
        export_dir, constants.VARIABLES_INDEX_FILENAME_V2)
    checkpoint_v2 = file_io.file_exists(variables_index_filename)

    # Find matching checkpoint files.
    if checkpoint_v2:
        # The checkpoint is in v2 format.
        variables_filename_pattern = os.path.join(
            export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
        variables_filename_list = file_io.get_matching_files(
            variables_filename_pattern)
        checkpoint_sharded = True
    else:
        variables_filename = os.path.join(export_dir,
                                          constants.VARIABLES_FILENAME)
        if file_io.file_exists(variables_filename):
            variables_filename_list = [variables_filename]
        else:
            variables_filename = os.path.join(
                export_dir, constants.VARIABLES_FILENAME_PATTERN)
            variables_filename_list = file_io.get_matching_files(
                variables_filename)
            checkpoint_sharded = True

    # Prepare the files to restore a session.
    if not variables_filename_list:
        restore_files = ""
    elif checkpoint_v2 or not checkpoint_sharded:
        # For checkpoint v2 or v1 with non-sharded files, use "export" to
        # restore the session.
        restore_files = constants.VARIABLES_FILENAME
    else:
        restore_files = constants.VARIABLES_FILENAME_PATTERN

    assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

    collection_def = meta_graph_def.collection_def
    graph_def = graph_pb2.GraphDef()
    if constants.GRAPH_KEY in collection_def:
        # Use serving graph_def in MetaGraphDef collection_def if exists
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        if len(graph_def_any) != 1:
            raise RuntimeError(
                "Expected exactly one serving GraphDef in : %s" %
                meta_graph_def)
        graph_def_any[0].Unpack(graph_def)
        # Replace the graph def in meta graph proto.
        meta_graph_def.graph_def.CopyFrom(graph_def)

    ops.reset_default_graph()
    sess = session.Session(target, graph=None, config=config)
    try:
        # Import the graph.
        saver = saver_lib.import_meta_graph(meta_graph_def)
        # Restore the session.
        if restore_files:
            saver.restore(sess, os.path.join(export_dir, restore_files))

        init_op_tensor = None
        if constants.INIT_OP_KEY in collection_def:
            init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
            if len(init_ops) != 1:
                raise RuntimeError(
                    "Expected exactly one serving init op in : %s" %
                    meta_graph_def)
            init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

        # Create asset input tensor list: each asset tensor name is fed the
        # path of its file inside the export's assets directory.
        asset_tensor_dict = {}
        if constants.ASSETS_KEY in collection_def:
            assets_any = collection_def[constants.ASSETS_KEY].any_list.value
            for asset in assets_any:
                asset_pb = manifest_pb2.AssetFile()
                asset.Unpack(asset_pb)
                asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = (
                    os.path.join(assets_dir, asset_pb.filename))

        # Compare against None explicitly rather than relying on the
        # truthiness of a graph element.
        if init_op_tensor is not None:
            # Run the init op.
            sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
    except Exception:
        # Do not leak the session when loading fails part-way through.
        sess.close()
        raise

    return sess, meta_graph_def