def init(self,
         graph_def=None,
         init_op=None,
         clear_devices=False,
         default_graph_signature=None,
         named_graph_signatures=None,
         assets=None,
         assets_callback=None):
    """Initialization.

    Records the export configuration by adding entries to collections of the
    default graph. A copy of the GraphDef is stored under GRAPH_KEY only when
    `graph_def` is given or `clear_devices` is set; the init op (if any) goes
    under INIT_OP_KEY, the packed Signatures proto under SIGNATURES_KEY, and
    one packed AssetFile proto per asset under ASSETS_KEY.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the graph.
      assets: An iterable (e.g. a list or set) of (filename, Tensor) tuples,
        where filename is a string and Tensor is the tensor whose name is
        recorded as the asset's tensor binding.
      assets_callback: callback with a single string argument; called during
        export with the asset path.

    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None.
    """
    # Avoid Dangerous default value [] by normalizing None here.
    if named_graph_signatures is None:
        named_graph_signatures = {}
    if assets is None:
        assets = []

    if self._has_init:
        raise RuntimeError("init should be called only once")
    # Validate arguments before mutating any state, so a rejected call does
    # not leave the exporter marked as initialized or the graph collections
    # partially populated.
    if init_op and not isinstance(init_op, ops.Operation):
        raise TypeError("init_op needs to be an Operation: %s" % init_op)
    self._has_init = True

    if graph_def or clear_devices:
        graph_copy = tf.GraphDef()
        if graph_def:
            graph_copy.CopyFrom(graph_def)
        else:
            graph_copy.CopyFrom(tf.get_default_graph().as_graph_def())
        if clear_devices:
            for node in graph_copy.node:
                node.device = ""
        graph_any_buf = any_pb2.Any()
        graph_any_buf.Pack(graph_copy)
        tf.add_to_collection(GRAPH_KEY, graph_any_buf)

    if init_op:
        tf.add_to_collection(INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
        signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
        signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = any_pb2.Any()
    signatures_any_buf.Pack(signatures_proto)
    tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)

    for filename, tensor in assets:
        asset = manifest_pb2.AssetFile()
        asset.filename = filename
        asset.tensor_binding.tensor_name = tensor.name
        asset_any_buf = any_pb2.Any()
        asset_any_buf.Pack(asset)
        tf.add_to_collection(ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback
def LoadSessionBundleFromPath(export_dir, target="", config=None):
    """Load session bundle from the given path.

    Resets the default graph, imports the exported meta graph into it,
    restores the variables into a newly created session, and runs the serving
    init op (if one was exported). Asset tensors are fed the paths of their
    files under the export's assets directory. If any step after the session
    is created raises, the session is closed before the exception propagates.

    Args:
      export_dir: the directory that contains files exported by exporter.
      target: The execution engine to connect to. See target in tf.Session()
      config: A ConfigProto proto with configuration options. See config in
        tf.Session()

    Returns:
      session: a tensorflow session created from the variable files.
      meta_graph_def: a meta graph proto saved in the exporter directory.

    Raises:
      RuntimeError: if the required files are missing or contain
        unrecognizable fields, i.e. the exported model is invalid.
    """
    meta_graph_filename = os.path.join(export_dir,
                                       constants.META_GRAPH_DEF_FILENAME)
    if not gfile.Exists(meta_graph_filename):
        raise RuntimeError("Expected meta graph file missing %s" %
                           meta_graph_filename)

    variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
    if not gfile.Exists(variables_filename):
        # Fall back to variable files matched by the filename pattern.
        variables_filename = os.path.join(
            export_dir, constants.VARIABLES_FILENAME_PATTERN)
        if not gfile.Glob(variables_filename):
            raise RuntimeError("Expected variables file missing %s" %
                               variables_filename)
    assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

    # Reads meta graph file. It holds a serialized binary proto, so it must be
    # read in binary mode; text mode may try to decode the bytes.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    with gfile.GFile(meta_graph_filename, "rb") as f:
        meta_graph_def.ParseFromString(f.read())

    collection_def = meta_graph_def.collection_def
    graph_def = tf.GraphDef()
    if constants.GRAPH_KEY in collection_def:
        # Use serving graph_def in MetaGraphDef collection_def if exists
        graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
        if len(graph_def_any) != 1:
            raise RuntimeError(
                "Expected exactly one serving GraphDef in : %s" % meta_graph_def)
        else:
            graph_def_any[0].Unpack(graph_def)
            # Replace the graph def in meta graph proto.
            meta_graph_def.graph_def.CopyFrom(graph_def)

    tf.reset_default_graph()
    sess = tf.Session(target, graph=None, config=config)
    try:
        # Import the graph.
        saver = tf.train.import_meta_graph(meta_graph_def)
        # Restore the session.
        saver.restore(sess, variables_filename)

        init_op_tensor = None
        if constants.INIT_OP_KEY in collection_def:
            init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
            if len(init_ops) != 1:
                raise RuntimeError(
                    "Expected exactly one serving init op in : %s" %
                    meta_graph_def)
            init_op_tensor = tf.get_collection(constants.INIT_OP_KEY)[0]

        # Create asset input tensor list: tensor name -> asset file path.
        asset_tensor_dict = {}
        if constants.ASSETS_KEY in collection_def:
            assets_any = collection_def[constants.ASSETS_KEY].any_list.value
            for asset in assets_any:
                asset_pb = manifest_pb2.AssetFile()
                asset.Unpack(asset_pb)
                asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = (
                    os.path.join(assets_dir, asset_pb.filename))

        if init_op_tensor is not None:
            # Run the init op.
            sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
    except BaseException:
        # Do not leak the session (and its resources) when loading fails;
        # catch BaseException so interrupts also release it, then re-raise.
        sess.close()
        raise

    return sess, meta_graph_def
def doBasicsOneExportPath(self,
                          export_path,
                          clear_devices=False,
                          global_step=GLOBAL_STEP):
    """Exports a small graph to `export_path` and validates the export.

    The graph has variables placed on two CPU devices, an init op, signatures
    and one asset. After exporting, the meta graph is re-imported into a fresh
    default graph, and its collections, the written asset file and the
    restored variable values are checked.

    Args:
      export_path: Base directory to export into.
      clear_devices: Passed to `Exporter.init`; when set, device fields are
        also cleared on the reference GraphDef before comparing.
      global_step: Value of the global step variable; it also selects the
        versioned export sub-directory that is read back.
    """
    # Build a graph with 2 parameter nodes on different devices.
    tf.reset_default_graph()
    with tf.Session(
            target="",
            config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        # v2 is an unsaved variable derived from v0 and v1. It is used to
        # exercise the ability to run an init op when restoring a graph.
        with sess.graph.device("/cpu:0"):
            v0 = tf.Variable(10, name="v0")
        with sess.graph.device("/cpu:1"):
            v1 = tf.Variable(20, name="v1")
            v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
            assign_v2 = tf.assign(v2, tf.add(v0, v1))
            init_op = tf.group(assign_v2, name="init_op")

        tf.add_to_collection("v", v0)
        tf.add_to_collection("v", v1)
        tf.add_to_collection("v", v2)

        global_step_tensor = tf.Variable(global_step, name="global_step")
        named_tensor_bindings = {
            "logical_input_A": v0,
            "logical_input_B": v1,
        }
        signatures = {
            "foo": exporter.regression_signature(input_tensor=v0,
                                                 output_tensor=v1),
            "generic": exporter.generic_signature(named_tensor_bindings),
        }

        def write_asset(path):
            file_path = os.path.join(path, "file.txt")
            with gfile.FastGFile(file_path, "w") as f:
                f.write("your data here")

        asset_file = tf.Variable("hello42.txt", name="filename42")
        assets = {("hello42.txt", asset_file)}

        tf.initialize_all_variables().run()

        # Run an export.
        save = tf.train.Saver({"v0": v0, "v1": v1},
                              restore_sequentially=True,
                              sharded=True)
        export = exporter.Exporter(save)
        export.init(
            sess.graph.as_graph_def(),
            init_op=init_op,
            clear_devices=clear_devices,
            default_graph_signature=exporter.classification_signature(
                input_tensor=v0),
            named_graph_signatures=signatures,
            assets=assets,
            assets_callback=write_asset)
        export.export(export_path,
                      global_step_tensor,
                      sess,
                      exports_to_keep=gc.largest_export_versions(2))

    # Restore graph.
    compare_def = tf.get_default_graph().as_graph_def()
    tf.reset_default_graph()
    with tf.Session(
            target="",
            config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        export_version_dir = os.path.join(
            export_path, exporter.VERSION_FORMAT_SPECIFIER % global_step)
        save = tf.train.import_meta_graph(
            os.path.join(export_version_dir, exporter.META_GRAPH_DEF_FILENAME))
        self.assertIsNotNone(save)
        meta_graph_def = save.export_meta_graph()
        collection_def = meta_graph_def.collection_def

        # Validate custom graph_def.
        graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value
        self.assertEqual(len(graph_def_any), 1)
        graph_def = tf.GraphDef()
        graph_def_any[0].Unpack(graph_def)
        if clear_devices:
            for node in compare_def.node:
                node.device = ""
        self.assertProtoEquals(compare_def, graph_def)

        # Validate init_op.
        init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value
        self.assertEqual(len(init_ops), 1)
        self.assertEqual(init_ops[0], "init_op")

        # Validate signatures.
        signatures_any = collection_def[exporter.SIGNATURES_KEY].any_list.value
        self.assertEqual(len(signatures_any), 1)
        signatures = manifest_pb2.Signatures()
        signatures_any[0].Unpack(signatures)
        default_signature = signatures.default_signature
        self.assertEqual(
            default_signature.classification_signature.input.tensor_name,
            "v0:0")
        bindings = signatures.named_signatures["generic"].generic_signature.map
        self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
        self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
        read_foo_signature = (
            signatures.named_signatures["foo"].regression_signature)
        self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
        self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

        # Validate the assets.
        assets_any = collection_def[exporter.ASSETS_KEY].any_list.value
        self.assertEqual(len(assets_any), 1)
        asset = manifest_pb2.AssetFile()
        assets_any[0].Unpack(asset)
        assets_path = os.path.join(export_version_dir,
                                   exporter.ASSETS_DIRECTORY, "file.txt")
        # Read inside a context manager so the file handle is closed.
        with gfile.GFile(assets_path) as f:
            asset_contents = f.read()
        self.assertEqual(asset_contents, "your data here")
        self.assertEqual("hello42.txt", asset.filename)
        self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)

        # Validate graph restoration.
        save.restore(
            sess,
            os.path.join(export_version_dir, exporter.VARIABLES_DIRECTORY))
        self.assertEqual(10, tf.get_collection("v")[0].eval())
        self.assertEqual(20, tf.get_collection("v")[1].eval())
        tf.get_collection(exporter.INIT_OP_KEY)[0].run()
        self.assertEqual(30, tf.get_collection("v")[2].eval())
def init(self,
         graph_def=None,
         init_op=None,
         clear_devices=False,
         default_graph_signature=None,
         named_graph_signatures=None,
         assets_collection=None,
         assets_callback=gfile_copy_callback):
    """Initialization.

    Records the export configuration by adding entries to collections of the
    default graph. A copy of the GraphDef is stored under GRAPH_KEY only when
    `graph_def` is given or `clear_devices` is set; the init op (if any) goes
    under INIT_OP_KEY, the packed Signatures proto under SIGNATURES_KEY, and
    one packed AssetFile proto per asset under ASSETS_KEY. Each asset's source
    filepath is also recorded in `self._assets_to_copy`, mapped to its
    basename.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the graph.
      assets_collection: A collection of constant asset filepath tensors. If set
        the assets will be exported into the asset directory.
      assets_callback: callback with two arguments called during export with
        the list of files to copy and the asset path.

    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None.
      ValueError: if asset file path tensors are not non-empty constant string
        scalar tensors.
    """
    # Checked first so that a repeated call cannot touch any exporter state.
    if self._has_init:
        raise RuntimeError("init should be called only once")

    # Avoid Dangerous default value []
    if named_graph_signatures is None:
        named_graph_signatures = {}

    # Validate every argument into locals first. The state set in this method
    # is only mutated after all checks pass, so a rejected call leaves the
    # exporter uninitialized and can be retried.
    assets = []  # (basename, tensor) pairs recorded as AssetFile protos.
    assets_to_copy = {}  # source filepath -> basename in the asset directory.
    if assets_collection:
        for asset_tensor in assets_collection:
            asset_filepath = self._file_path_value(asset_tensor)
            if not asset_filepath:
                raise ValueError(
                    "invalid asset filepath tensor %s" % asset_tensor)
            basename = os.path.basename(asset_filepath)
            assets.append((basename, asset_tensor))
            assets_to_copy[asset_filepath] = basename
    if init_op and not isinstance(init_op, ops.Operation):
        raise TypeError("init_op needs to be an Operation: %s" % init_op)

    self._has_init = True
    for asset_filepath, basename in assets_to_copy.items():
        self._assets_to_copy[asset_filepath] = basename

    if graph_def or clear_devices:
        graph_copy = tf.GraphDef()
        if graph_def:
            graph_copy.CopyFrom(graph_def)
        else:
            graph_copy.CopyFrom(tf.get_default_graph().as_graph_def())
        if clear_devices:
            for node in graph_copy.node:
                node.device = ""
        graph_any_buf = any_pb2.Any()
        graph_any_buf.Pack(graph_copy)
        tf.add_to_collection(constants.GRAPH_KEY, graph_any_buf)

    if init_op:
        tf.add_to_collection(constants.INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
        signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
        signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = any_pb2.Any()
    signatures_any_buf.Pack(signatures_proto)
    tf.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf)

    for filename, tensor in assets:
        asset = manifest_pb2.AssetFile()
        asset.filename = filename
        asset.tensor_binding.tensor_name = tensor.name
        asset_any_buf = any_pb2.Any()
        asset_any_buf.Pack(asset)
        tf.add_to_collection(constants.ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback