def get_metagraph():
  """Build the MetaGraphDef described by the command-line flags.

  If --metagraphdef is set, that file is parsed (text format when its name
  ends in ".pbtxt", binary otherwise). When --fetch is also given, its
  comma-separated names replace the "train_op" collection.

  Otherwise --graphdef is parsed the same way and imported into the default
  graph without a name prefix. Every comma-separated --fetch operation is
  added to the "train_op" collection, and a MetaGraphDef is exported from
  that graph.

  Returns:
    A `meta_graph_pb2.MetaGraphDef`.
  """

  def _parse_into(proto, path, handle):
    # ".pbtxt" means text-format protobuf; anything else is binary.
    if path.endswith(".pbtxt"):
      text_format.Merge(handle.read(), proto)
    else:
      proto.ParseFromString(handle.read())

  if not FLAGS.metagraphdef:
    graph_def = graph_pb2.GraphDef()
    with gfile.GFile(FLAGS.graphdef) as graph_file:
      _parse_into(graph_def, FLAGS.graphdef, graph_file)
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    for op_name in FLAGS.fetch.split(","):
      graph.add_to_collection("train_op", graph.get_operation_by_name(op_name))
    return saver.export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)

  metagraph = meta_graph_pb2.MetaGraphDef()
  with gfile.GFile(FLAGS.metagraphdef) as meta_file:
    _parse_into(metagraph, FLAGS.metagraphdef, meta_file)
  if FLAGS.fetch is not None:
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(FLAGS.fetch.split(","))
    metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
  return metagraph
def testDefaultAttrsRemoved(self):
  """Check how an attr known only to producer_op_list survives import.

  If the attr equals the producer-side default it is stripped on import;
  any other value is preserved.
  """
  producer_ops = op_def_pb2.OpList()
  text_format.Merge(""" op { name: 'OpWithFutureDefaultAttr' attr { name: 'default_int' type: 'int' default_value { i: 456 } } } """, producer_ops)

  def _import_op_a(graph_text):
    # Import a one-node graph and return the operation named 'A'.
    imported = importer.import_graph_def(
        self._MakeGraphDef(graph_text),
        return_elements=["A"],
        producer_op_list=producer_ops)
    return imported[0]

  # Attr only in producer_op_list with default value gets removed.
  with ops.Graph().as_default():
    op_a = _import_op_a(""" node { name: 'A' op: 'OpWithFutureDefaultAttr' attr { key: 'default_int' value { i: 456 } } } """)
    with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
      op_a.get_attr("default_int")

  # Attr only in producer_op_list with non-default value is preserved.
  with ops.Graph().as_default():
    op_a = _import_op_a(""" node { name: 'A' op: 'OpWithFutureDefaultAttr' attr { key: 'default_int' value { i: 987 } } } """)
    self.assertEqual(987, op_a.get_attr("default_int"))
def testInvalidInputForInputMap(self):
  """Test input_map validation: its type, non-Tensor values, and success."""
  # A list instead of a dict is rejected with TypeError.
  with ops.Graph().as_default():
    with self.assertRaises(TypeError) as raised:
      importer.import_graph_def(
          self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
    self.assertEqual(
        "input_map must be a dictionary mapping strings to Tensor objects.",
        str(raised.exception))

  placeholder_graph = self._MakeGraphDef(""" node { name: 'a' op: 'Placeholder' attr { key: 'dtype' value { type: DT_FLOAT } }} node { name: 'id' op: 'Identity' input: 'a:0' attr { key: 'T' value { type: DT_FLOAT } }}""")

  # A non-Tensor value (a Variable) combined with an empty `name` is rejected.
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as raised:
      importer.import_graph_def(
          placeholder_graph,
          input_map={"a:0": variables.Variable(5.0)},
          name="")
    self.assertStartsWith(
        str(raised.exception),
        "tf.import_graph_def() requires a non-empty `name` if `input_map` "
        "contains non-Tensor values.")

  # A Tensor value is accepted and flows through the Identity node.
  with ops.Graph().as_default():
    identity_out, = importer.import_graph_def(
        placeholder_graph,
        input_map={"a:0": constant_op.constant(5.0)},
        name="",
        return_elements=["id:0"])
    with self.test_session():
      self.assertEqual(5.0, identity_out.eval())
def testWithDeviceFunctionDependingOnInputs(self):
  """Check that a device function sees each imported op's inputs.

  The source graph holds an add and a subtract (two inputs each) plus an
  identity (one input). Only the first two should be recorded.
  """
  if ops._USE_C_API:
    return  # TODO(skyewm): make this work with C API

  with ops.Graph().as_default() as source_graph:
    with ops.device("/job:ps"):
      lhs = constant_op.constant(1.0)
      rhs = constant_op.constant(1.0)
      _ = lhs + rhs
      _ = lhs - rhs
      _ = array_ops.identity(lhs)
  source_graph_def = source_graph.as_graph_def()

  binary_ops = []

  def _record_binary_ops(op):
    # Used only as an observer: record two-input ops, assign no device.
    if len(op.inputs) == 2:
      binary_ops.append(op)
    return ""

  with ops.Graph().as_default():
    with ops.device(_record_binary_ops):
      importer.import_graph_def(source_graph_def)

  # We expect to see the add and subtract, but not identity.
  self.assertEqual(2, len(binary_ops))
def testImportGraphWithFunctionTwice(self):
  """Import a GraphDef that calls a Defun twice, under different scopes."""
  function_graph = ops.Graph()
  with function_graph.as_default():

    @function.Defun()
    def Add2(x, y):
      return math_ops.add(x, y)

    x_ph = array_ops.placeholder(dtype=dtypes.float32, name="x")
    y_ph = array_ops.placeholder(dtype=dtypes.float32, name="y")
    _ = Add2(x_ph, y_ph, name="z")  # pylint: disable=unexpected-keyword-arg
  gdef = function_graph.as_graph_def()

  # Both imports are fed from the same random scalars, so they must agree.
  input_map = {
      "x:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
      "y:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
  }

  outputs = []
  for scope_name in ("first", "second"):
    with ops.name_scope(scope_name):
      outputs.append(
          importer.import_graph_def(
              gdef, return_elements=["z:0"], input_map=input_map)[0])

  with self.test_session() as sess:
    first_val, second_val = sess.run(tuple(outputs))
    self.assertAllEqual(first_val, second_val)
def run_graph_def(graph_def, input_map, outputs):
  """Import `graph_def` into a new graph and evaluate `outputs` there.

  Args:
    graph_def: The `GraphDef` to import, with no name prefix.
    input_map: Used as the session `feed_dict`. It is not forwarded to
      `import_graph_def`, which receives an empty mapping.
    outputs: Fetches passed to `Session.run`.

  Returns:
    Whatever `Session.run(outputs, feed_dict=input_map)` returns.
  """
  fresh_graph = ops_lib.Graph()
  with fresh_graph.as_default():
    importer.import_graph_def(graph_def, input_map={}, name="")
  with session.Session(graph=fresh_graph) as sess:
    return sess.run(outputs, feed_dict=input_map)
def testNamePrefixColocationAttrsMultipleImport(self):
  """Re-importing uniquifies names and rewrites colocation attrs to match."""
  if ops._USE_C_API:
    return  # TODO(skyewm): set uniquify_names

  source_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' } node { name: 'B' op: 'None' attr { key: '_class' value { list { s: 'loc:@A' } } } }""")
  # After the second import, the copy is A_1/B_1 and B_1 points at A_1.
  expected_text = """ node { name: 'A' op: 'None' } node { name: 'B' op: 'None' attr { key: '_class' value { list { s: 'loc:@A' } } } } node { name: 'A_1' op: 'None' } node { name: 'B_1' op: 'None' attr { key: '_class' value { list { s: 'loc:@A_1' } } } }"""

  with ops.Graph().as_default():
    first_b, = importer.import_graph_def(
        source_def, return_elements=["B"], name="")
    _, = importer.import_graph_def(
        source_def, return_elements=["B"], name="")
    self.assertProtoEqualsVersion(expected_text, first_b.graph.as_graph_def())
def main(_):
  """Print a cost report for the graph selected by the command-line flags.

  The graph comes from --metagraphdef when set, otherwise from --graphdef.
  Either file may be binary or, when its name ends in ".pbtxt", text format.

  For a GraphDef, the comma-separated --fetch operations are added to the
  "train_op" collection before a MetaGraphDef is exported. A single name
  without commas behaves exactly as before.

  When --rewriter_config is given, the graph is first optimized with that
  RewriterConfig. The report from `cost_analyzer.GenerateCostReport`
  (honoring --per_node_report) is printed to stdout.
  """
  if FLAGS.metagraphdef:
    with gfile.GFile(FLAGS.metagraphdef) as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      # Accept text-format MetaGraphDefs too, mirroring the GraphDef branch.
      if FLAGS.metagraphdef.endswith(".pbtxt"):
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
  else:
    with gfile.GFile(FLAGS.graphdef) as graph_file:
      graph_def = graph_pb2.GraphDef()
      if FLAGS.graphdef.endswith(".pbtxt"):
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    # Operation names cannot contain commas, so splitting is safe and allows
    # several fetches, consistent with get_metagraph().
    for fetch_name in FLAGS.fetch.split(","):
      graph.add_to_collection("train_op",
                              graph.get_operation_by_name(fetch_name))
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
def testMissingControlInputInGraphDef(self):
  """Importing a node whose control input '^A' is undefined must fail."""
  dangling_control = """ node { name: 'B' op: 'None' input: '^A' } """
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError,
                                 r"Node 'B': Unknown input node '\^A'"):
      importer.import_graph_def(self._MakeGraphDef(dangling_control))
def testDuplicateOperationNames(self):
  """Reject a GraphDef that defines node 'A' twice on import."""
  # Import into a throwaway graph, like the neighbouring import tests, so the
  # process-wide default graph is never touched by this test.
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError, "Node 'A' is not unique"):
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'A' op: 'IntOutput' } node { name: 'B' op: 'IntOutput' } node { name: 'A' op: 'IntOutput' } """))
def testMissingInputOpInGraphDef(self):
  """Reject a node whose data input 'A:0' names no node in the GraphDef."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'B' op: 'If' input: 'A:0' } """))
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("Input tensor 'A:0' not found", str(e.exception))
def testMissingControlInputInGraphDef(self):
  """Reject a node whose control input '^A' names no node in the GraphDef."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'B' op: 'None' input: '^A' } """))
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("Control input '^A' not found", str(e.exception))
def testMissingInputOpInGraphDef(self):
  """Importing a node whose data input 'A:0' is undefined must fail."""
  dangling_input = """ node { name: 'B' op: 'FloatInput' input: 'A:0' } """
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError,
                                 "Node 'B': Unknown input node 'A:0'"):
      importer.import_graph_def(self._MakeGraphDef(dangling_input))
def testInvalidTensorNameInGraphDef(self):
  """An input string with two colons ('A:B:0') is reported as unknown."""
  malformed_input = """ node { name: 'B' op: 'None' input: 'A:B:0' } """
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError,
                                 "Node 'B': Unknown input node 'A:B:0'"):
      importer.import_graph_def(self._MakeGraphDef(malformed_input))
def testVersionLow(self):
  """A GraphDef whose producer version is below the supported minimum fails."""
  expected_pattern = (
      r"GraphDef producer version -1 below min producer %d supported "
      r"by TensorFlow \S+\. Please regenerate your graph.$" %
      versions.GRAPH_DEF_VERSION_MIN_PRODUCER)
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(Exception, expected_pattern):
      importer.import_graph_def(self._MakeGraphDef("", producer=-1))
def SetProducerVersion(self, graph, producer_version):
  """Set `graph.graph_def_versions.producer` to `producer_version`.

  Args:
    graph: The `Graph` whose producer version should be changed.
    producer_version: The desired producer version (an int).

  Raises:
    AssertionError: if, after the import, the graph's producer version does
      not equal `producer_version`.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # The previous `assert producer, producer_version` only checked that the
  # producer was non-zero (the second operand is the assertion *message*).
  # Compare explicitly, and raise directly so the check survives `python -O`.
  actual = graph.graph_def_versions.producer
  if actual != producer_version:
    raise AssertionError("Expected graph producer version %d, got %d" %
                         (producer_version, actual))
def testVersionHigh(self):
  """A GraphDef requiring a newer consumer than this build is rejected."""
  required_consumer = 1 << 30
  expected_pattern = (
      r"GraphDef min consumer version %d above current version %d "
      r"for TensorFlow \S+\. Please upgrade TensorFlow\.$" %
      (required_consumer, versions.GRAPH_DEF_VERSION))
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError, expected_pattern):
      importer.import_graph_def(
          self._MakeGraphDef("", min_consumer=required_consumer))
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """Run trt_convert.TrtGraphConverter() and verify every produced graph.

  The converted GraphDef is always checked. When `output_saved_model_dir` is
  given, the GraphDef reloaded from that SavedModel is checked too. With
  `need_calibration`, each TRTEngineOp must carry calibration data and the
  graph must compute (x + 1)**2 for inputs 0..9.
  """
  to_verify = [
      self._ConvertGraph(
          input_saved_model_dir=input_saved_model_dir,
          output_saved_model_dir=output_saved_model_dir,
          need_calibration=need_calibration,
          is_dynamic_op=is_dynamic_op,
          use_function_backup=need_calibration)
  ]

  if output_saved_model_dir:
    if not context.executing_eagerly():
      reloaded = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
    else:
      loaded_root = load.load(output_saved_model_dir)
      serving_fn = loaded_root.signatures[
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
      reloaded = serving_fn.graph.as_graph_def()
    self.assertTrue(isinstance(reloaded, graph_pb2.GraphDef))
    to_verify.append(reloaded)

  for gdef in to_verify:
    op_by_node = dict((n.name, n.op) for n in gdef.node)
    if not context.executing_eagerly():
      self.assertEqual({
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity"
      }, op_by_node)
    else:
      # In V2 the actual graph could be inside a function.
      for fdef in gdef.library.function:
        op_by_node.update((n.name, n.op) for n in fdef.node_def)
      self.assertIn("TRTEngineOp_0", op_by_node)
      self.assertEqual("TRTEngineOp", op_by_node["TRTEngineOp_0"])

    if not need_calibration:
      continue

    engines = [n for n in gdef.node if n.op == "TRTEngineOp"]
    self.assertNotEmpty(engines)
    for engine in engines:
      self.assertTrue(len(engine.attr["calibration_data"].s))

    # Run the calibrated graph.
    # TODO(laigd): consider having some input where the answer is different.
    with ops.Graph().as_default():
      importer.import_graph_def(gdef, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        for value in range(10):
          self.assertEqual(
              (value + 1.0)**2,
              sess.run("output:0", feed_dict={"input:0": [[[value]]]}))
def testMissingInputMap(self):
  """Reject an input_map key ('B:0') that does not exist in the GraphDef."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'A' op: 'None' } """),
          input_map={"B:0": constant_op.constant(5.0)})
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("not found in graph_def: [B:0]", str(e.exception))
def testInvalidTensorNameInGraphDef(self):
  """An input string with two colons cannot be parsed as a tensor name."""
  malformed_input = """ node { name: 'B' op: 'None' input: 'A:B:0' } """
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as raised:
      importer.import_graph_def(self._MakeGraphDef(malformed_input))
    self.assertEqual("Cannot convert 'A:B:0' to a tensor name.",
                     str(raised.exception))
def testMissingReturnOperation(self):
  """Asking for a return node absent from the GraphDef raises ValueError."""
  single_node = """ node { name: 'A' op: 'None' } """
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(
        ValueError, "Requested return node 'B' not found in graph def"):
      importer.import_graph_def(
          self._MakeGraphDef(single_node), return_elements=["B"])
def testNamePrefixColocationAttrsNotFound(self):
  """A colocation attr pointing at a node missing from the import fails."""
  dangling_colocation = """ node { name: 'B' op: 'None' attr { key: '_class' value { list { s: 'loc:@A' } } } }"""
  source_def = self._MakeGraphDef(dangling_colocation)
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError, "does not exist during import"):
      importer.import_graph_def(
          source_def, return_elements=["B"], name="imported_graph")
def testDuplicateOperationNames(self):
  """A GraphDef that names two nodes 'A' is rejected with a precise message."""
  duplicated = """ node { name: 'A' op: 'Oi' } node { name: 'B' op: 'Oi' } node { name: 'A' op: 'Oi' } """
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as raised:
      importer.import_graph_def(self._MakeGraphDef(duplicated))
    self.assertEqual("Duplicate name 'A' in GraphDef.", str(raised.exception))
def testMissingReturnOperation(self):
  """Reject a requested return element ('B') absent from the GraphDef."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'A' op: 'None' } """),
          return_elements=["B"])
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("return_element 'B' not found in graph_def.",
                  str(e.exception))
def testInvalidInputForReturnOperations(self):
  """Non-string and unparsable return_elements are both rejected."""
  bad_cases = [
      (TypeError, "return_elements must be a list of strings.", [7]),
      (ValueError, "Cannot convert 'a:b:c' to a tensor name.", ["a:b:c"]),
  ]
  with ops.Graph().as_default():
    for error_type, pattern, elements in bad_cases:
      with self.assertRaisesRegexp(error_type, pattern):
        importer.import_graph_def(
            self._MakeGraphDef(""), return_elements=elements)
def testInvalidSignatureNotEnoughInputsInGraphDef(self):
  """Reject an 'Iif' node given only one of its (int32, float32) inputs."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'A' op: 'Oi' } node { name: 'B' op: 'Iif' input: 'A:0' } """))
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn(
        "Input types mismatch (expected 'int32, float32' but got 'int32')",
        str(e.exception))
def _convert_graph_def(self):
  """Convert `self._input_graph_def`.

  The GraphDef is imported into a scratch graph. A MetaGraphDef (with shapes)
  is exported into `self._grappler_meta_graph_def`. The node blacklist is
  then added and the conversion is run.
  """
  scratch_graph = ops.Graph()
  with scratch_graph.as_default():
    importer.import_graph_def(self._input_graph_def, name="")
    exported_graph_def = scratch_graph.as_graph_def(add_shapes=True)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=exported_graph_def, graph=scratch_graph)
  self._add_nodes_blacklist()
  self._run_conversion()
def testMissingInputMap(self):
  """Mapping an input ('B:0') that the GraphDef lacks raises ValueError."""
  expected_pattern = (
      r"Attempted to map inputs that were not found in graph_def: \[B:0\]")
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError, expected_pattern):
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'A' op: 'None' } """),
          input_map={"B:0": constant_op.constant(5.0)})
def testMissingControlInputInGraphDef(self):
  """Reject a node whose control input '^A' names no node in the GraphDef."""
  if ops._USE_C_API:
    return  # TODO(skyewm): make this work with C API

  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'B' op: 'None' input: '^A' } """))
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("Control input '^A' not found", str(e.exception))
def testInvalidSignatureTooManyInputsInGraphDef(self):
  """Reject an input given to a 'None' op, which is defined here with none."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef(""" node { name: 'A' op: 'Oi' } node { name: 'B' op: 'None' input: 'A:0' } """))
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("More inputs specified ('A:0') than the op expects",
                  str(e.exception))
def testOldGraph(self):
  """A producer-21 graph whose placeholder has `shape {}` accepts any shape.

  The graph below corresponds to:

    a = tf.placeholder(tf.float32)
    b = a + 1.0

  It was generated by an earlier TensorFlow that serialized the unset
  placeholder shape as 'shape {}' rather than 'shape { unknown_rank: true }'.
  """
  old_graph_text = """ node { name: "Placeholder" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } } node { name: "add/y" op: "Const" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape { } float_val: 1.0 } } } } node { name: "add" op: "Add" input: "Placeholder" input: "add/y" attr { key: "T" value { type: DT_FLOAT } } } versions { producer: 21 } """
  old_graph_def = graph_pb2.GraphDef()
  text_format.Merge(old_graph_text, old_graph_def)

  with self.test_session():
    placeholder, total = importer.import_graph_def(
        old_graph_def, return_elements=["Placeholder:0", "add:0"])
    # Feeding a two-element vector works because producer 21 reads `shape {}`
    # as "any shape"; at producer 22 this would be a shape mismatch.
    self.assertAllEqual([2.0, 3.0],
                        total.eval(feed_dict={placeholder: [1.0, 2.0]}))
def testInvalidInputForGraphDef(self):
  """Passing a plain string instead of a GraphDef raises TypeError."""
  not_a_proto = ""
  with ops.Graph().as_default():
    with self.assertRaises(TypeError) as raised:
      importer.import_graph_def(not_a_proto)
    self.assertEqual("graph_def must be a GraphDef proto.",
                     str(raised.exception))
def _imports_graph_def():
  """Import the enclosing scope's `graph_def` into the current default graph.

  No name prefix is used. `graph_def` is a free variable; it is assumed to be
  bound in an enclosing function that is not visible here.
  """
  importer.import_graph_def(graph_def=graph_def, name="")
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """Convert a graph via trt_convert.TrtGraphConverter() and check results.

  Every produced GraphDef is checked: the converted one and, if
  `output_saved_model_dir` is set, the one reloaded from the SavedModel
  (signature `_SAVED_MODEL_SIGNATURE_KEY` in eager mode). Each must contain
  the single engine node "TRTEngineOp_0". When calibration is required,
  engines must hold calibration data and the graph must compute (x + 1)**2.
  """

  def _reload_saved_graph_def():
    # Eager mode goes through the V2 loader; graph mode reads the
    # serving-tagged MetaGraphDef directly.
    if context.executing_eagerly():
      loaded_root = load.load(output_saved_model_dir)
      return loaded_root.signatures[
          _SAVED_MODEL_SIGNATURE_KEY].graph.as_graph_def()
    return saved_model_utils.get_meta_graph_def(
        output_saved_model_dir, tag_constants.SERVING).graph_def

  def _check_calibrated(gdef):
    # Every engine must carry calibration data, and the graph must still run.
    engines = [n for n in gdef.node if n.op == "TRTEngineOp"]
    self.assertNotEmpty(engines)
    for engine in engines:
      self.assertTrue(len(engine.attr["calibration_data"].s))
    # Run the calibrated graph.
    # TODO(laigd): consider having some input where the answer is different.
    with ops.Graph().as_default():
      importer.import_graph_def(gdef, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        for value in range(10):
          self.assertEqual(
              (value + 1.0)**2,
              sess.run("output:0", feed_dict={"input:0": [[[value]]]}))

  converted = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      use_function_backup=need_calibration)
  to_verify = [converted]
  if output_saved_model_dir:
    reloaded = _reload_saved_graph_def()
    self.assertTrue(isinstance(reloaded, graph_pb2.GraphDef))
    to_verify.append(reloaded)

  for gdef in to_verify:
    op_by_node = {n.name: n.op for n in gdef.node}
    if context.executing_eagerly():
      # In V2 the actual graph could be inside a function.
      for fdef in gdef.library.function:
        op_by_node.update({n.name: n.op for n in fdef.node_def})
      self.assertIn("TRTEngineOp_0", op_by_node)
      self.assertEqual("TRTEngineOp", op_by_node["TRTEngineOp_0"])
    else:
      expected_ops = {
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity",
      }
      self.assertEqual(expected_ops, op_by_node)
    if need_calibration:
      _check_calibrated(gdef)
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: Path to the GraphDef file to freeze.
    input_saver: Optional path to a SaverDef file; when empty, variables are
      restored by name from the checkpoint's variable list instead.
    input_binary: Whether the GraphDef/SaverDef files are binary protos (as
      opposed to text format).
    input_checkpoint: Checkpoint path or prefix to restore from.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: Path the frozen GraphDef is written to (binary).
    clear_devices: Whether to strip device placements from every node.
    initializer_nodes: Passed to `Session.run` before freezing when non-empty.
    variable_names_blacklist: Comma-separated variable names that must NOT be
      converted to constants.

  Returns:
    -1 if a precondition fails (a message is printed); otherwise None.
  """
  # Unused by updated loading code.
  del restore_op_name, filename_tensor_name

  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      graph_text = f.read()
      # A text-mode read may yield `bytes` or `str` depending on the Python
      # version; only `bytes` needs (or has) `.decode`.
      if isinstance(graph_text, bytes):
        graph_text = graph_text.decode("utf-8")
      text_format.Merge(graph_text, input_graph_def)

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""

  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver:
      with gfile.FastGFile(input_saver, mode) as f:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = saver_lib.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      # Restore only the checkpointed variables that also exist in the graph.
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # Checkpoint entry with no matching tensor in the graph; skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)

    if initializer_nodes:
      sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(",")
                                if variable_names_blacklist else None)
    # The blacklist used to be computed but never passed on, so it was
    # silently ignored.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants.

  Variables are restored either through the SaverDef in `input_saver` or, if
  none is given, by running `restore_op_name` with `filename_tensor_name` fed
  the checkpoint path. The frozen GraphDef is written to `output_graph`.

  Returns:
    -1 when a precondition fails (after printing why); otherwise None.
  """
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  if input_saver and not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1
  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1
  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  read_mode = "rb" if input_binary else "r"
  graph_def = graph_pb2.GraphDef()
  with gfile.FastGFile(input_graph, read_mode) as graph_file:
    serialized_graph = graph_file.read()
  if input_binary:
    graph_def.ParseFromString(serialized_graph)
  else:
    text_format.Merge(serialized_graph, graph_def)

  # Dropping device placements makes the frozen graph more portable.
  if clear_devices:
    for node in graph_def.node:
      node.device = ""

  _ = importer.import_graph_def(graph_def, name="")

  with session.Session() as sess:
    if not input_saver:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
    else:
      saver_def = saver_pb2.SaverDef()
      with gfile.FastGFile(input_saver, read_mode) as saver_file:
        if input_binary:
          saver_def.ParseFromString(saver_file.read())
        else:
          text_format.Merge(saver_file.read(), saver_def)
      saver_lib.Saver(saver_def=saver_def).restore(sess, input_checkpoint)

    if initializer_nodes:
      sess.run(initializer_nodes)

    if variable_names_blacklist:
      blacklist = variable_names_blacklist.split(",")
    else:
      blacklist = None
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def,
        output_node_names.split(","),
        variable_names_blacklist=blacklist)

  with gfile.GFile(output_graph, "wb") as out_file:
    out_file.write(frozen_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(frozen_graph_def.node))
def save(self, output_saved_model_dir):
  """Save the converted graph as a SavedModel.

  Args:
    output_saved_model_dir: construct a SavedModel using the converted
      GraphDef and save it to the specified directory. This option only works
      when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None in
      __init__().

  Raises:
    ValueError: if the input to the converter is a GraphDef instead of a
      SavedModel.
  """
  assert self._converted
  if self._need_calibration:
    assert self._calibration_data_collected
  if self._input_graph_def:
    raise ValueError(
        "Not able to save to a SavedModel since input is a GraphDef")

  def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
    """Restores collections that we need to keep."""
    scope = ""
    for key in collection_keys:
      collection_def = src_meta_graph_def.collection_def[key]
      kind = collection_def.WhichOneof("kind")
      if kind is None:
        tf_logging.error(
            "Cannot identify data type for collection %s. Skipping.", key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        # It is assumed that there are no Variables Keys in collections
        for value in collection_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          try:
            new_value = from_proto(proto, import_scope=scope)
          except Exception:  # pylint: disable=broad-except
            # Best effort: skip entries that cannot be rebuilt in the new
            # graph. Unlike a bare `except:`, this no longer swallows
            # KeyboardInterrupt/SystemExit.
            continue
          dest_graph.add_to_collection(key, new_value)
      else:
        field = getattr(collection_def, kind)
        if kind == "node_list":
          for value in field.value:
            name = ops.prepend_name_scope(value, scope)
            # Since the graph has been optimized, the node may no longer
            # exist.
            try:
              col_op = dest_graph.as_graph_element(name)
            except (TypeError, ValueError, KeyError):
              continue
            dest_graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the
          # fact that Python2 distinguishes between int and long, while
          # Python3 has only int.
          for value in field.value:
            dest_graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            dest_graph.add_to_collection(key,
                                         ops.prepend_name_scope(value, scope))

  # Write the transformed graphdef as SavedModel.
  saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
  with ops.Graph().as_default():
    importer.import_graph_def(self._converted_graph_def, name="")
    _restore_collections(
        ops.get_default_graph(), self._grappler_meta_graph_def,
        self._collections_to_keep(
            self._grappler_meta_graph_def.collection_def))
    # We don't use any specific converter here.
    with session.Session(config=self._session_config) as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess,
          self._input_saved_model_tags,
          signature_def_map=self._grappler_meta_graph_def.signature_def)
  # Ignore other meta graphs from the input SavedModel.
  saved_model_builder.save()
def testColocationWithDeviceFn(self):
  """Check how device functions and colocation interact when B is tied to A.

  B is colocated with A. Whatever device the device function picks for A
  (possibly none) is also what B ends up with. Both ops keep the prefixed
  colocation group "loc:@imported_graph/A".
  """
  original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' attr { key: '_class' value { list { s: 'loc:@A' } } } } node { name: 'B' op: 'None' attr { key: '_class' value { list { s: 'loc:@A' } } } }""")

  def _check(device_fn, expected_a_device, expected_b_device):
    # Import A and B under `device_fn` and verify devices and colocation.
    with ops.Graph().as_default():
      with ops.device(device_fn):
        op_a, op_b = importer.import_graph_def(
            original_graph_def,
            return_elements=["A", "B"],
            name="imported_graph")
      self.assertEqual(op_a.device, expected_a_device)
      self.assertEqual(op_b.device, expected_b_device)
      self.assertEqual(op_a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(op_b.colocation_groups(), [b"loc:@imported_graph/A"])

  # Places "A" on one device and "B" on another. Because B is colocated with
  # A, B's device function result is overridden by A's.
  def place_a_and_b(op):
    if "A" in op.name:
      return "/device:A:0"
    else:
      return "/device:B:0"

  # Only B asks for a device, so A gets none. B's request is still overridden
  # by A, leaving both unplaced.
  def place_only_b(op):
    if "B" in op.name:
      return "/device:B:0"
    return ""

  # Only A gets a device, so B inherits it implicitly.
  def place_only_a(op):
    if "A" in op.name:
      return "/device:A:0"
    return ""

  _check(place_a_and_b, "/device:A:0", "/device:A:0")
  _check(place_only_b, "", "")
  _check(place_only_a, "/device:A:0", "/device:A:0")
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef` (also written to `output_graph` when that is set).

  Raises:
    ValueError: if the checkpoint is missing (and no SavedModel dir is
      given), if `output_node_names` is empty, if the model contains
      partition variables and is loaded from a checkpoint, or if the model
      has no variables (e.g. it was already frozen).
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.")
        # Models that have been frozen previously do not contain Variables.
        if _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.")
        # Bare `raise` re-raises the original TypeError with its traceback
        # (the old `return 0` after the ValueError above was unreachable).
        raise

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that are not exactly covered by `input_map`.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # An empty CollectionDef has no "kind" set; getattr(col_def, None)
        # would raise TypeError, so treat it as having no unbound inputs.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Normalize once: values may be bytes while input_map keys are str.
        # The same normalized names drive both the check and the error
        # message, so the message lists only the inputs that are missing.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (not input_map or
                              sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([name for name in unbound_names
                                     if not input_map or
                                     name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Computed without reserving the name; used to re-scope collection entries
    # and variable protos to where import_graph_def places the nodes.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # The same serialized variable can appear in several variable
            # collections; reuse the object so it is only created once.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_blacklist=""):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: `GraphDef` to freeze. Its nodes' devices are cleared in
      place when `clear_devices` is true.
    input_saver_def: Optional `SaverDef`. When falsy, a `Saver` is built from
      the checkpoint variables that also exist as tensors in the graph.
    input_checkpoint: Checkpoint path (or V2 prefix) to restore from.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: Path the frozen, serialized `GraphDef` is written to.
    clear_devices: Whether to strip device placements before importing.
    initializer_nodes: Comma-separated names of nodes to run after restoring
      the checkpoint; may be empty.
    variable_names_blacklist: Comma-separated names of variables that must not
      be converted to constants; may be empty.

  Returns:
    -1 if the checkpoint does not exist or no output node names were given;
    otherwise None (the frozen graph is written to `output_graph`).
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  def _split_names(names):
    # Comma-separated flag value -> list of names; "a, b" is accepted as well
    # as "a,b". Without this, "init1,init2" would be treated as one op name.
    return names.replace(" ", "").split(",")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""

  # Imported into the default graph, which the Session below runs against.
  _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
    if initializer_nodes:
      sess.run(_split_names(initializer_nodes))

    variable_names_blacklist = (_split_names(variable_names_blacklist)
                                if variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        _split_names(output_node_names),
        variable_names_blacklist=variable_names_blacklist)

  with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
def testFunctions(self):
  """Checks that Defun functions survive GraphDef export and re-import.

  Builds a graph calling three kinds of functions: one with a custom gradient
  (`FuncWithGrad`), one that captures the outer-graph tensor `c`
  (`ExternalTensorFunc`), and one that calls a nested Defun (`OuterFunc`).
  The GraphDef is imported into a fresh graph, and forward values plus the
  custom gradient are checked. That graph is then exported and re-imported,
  and the same checks are repeated.
  """
  dtype = dtypes.float32

  @function.Defun(dtype, dtype, dtype, dtype)
  def Grad(x, y, dout1, dout2):  # pylint: disable=unused-argument
    # Return the inputs for simplicity of testing. The correct return value
    # would be (dout1 + dout2, dout1 - dout2)
    return x, y

  @function.Defun(dtype, dtype, grad_func=Grad)
  def FuncWithGrad(x, y):
    return x + y, x - y

  @function.Defun(dtypes.int32)
  def ExternalTensorFunc(x):
    # c must be defined in the containing graph; it is bound below, before
    # this function is first called.
    return x + c

  @function.Defun(dtypes.int32, dtypes.int32)
  def OuterFunc(x, y):

    @function.Defun(dtypes.int32)
    def InnerFunc(x):
      return x + x

    return InnerFunc(x) + y

  # Create graph with function calls and export to GraphDef
  with ops.Graph().as_default() as g1:
    p1 = array_ops.placeholder(dtype, name="p1")
    p2 = array_ops.placeholder(dtype, name="p2")
    # pylint: disable=unexpected-keyword-arg
    a, b = FuncWithGrad(p1, p2, name="f")

    c = constant_op.constant(10, dtype=dtypes.int32)
    # external:0 == 1 + c == 11
    ExternalTensorFunc(1, name="external")

    # outer:0 == (10 + 10) + 1 == 21
    OuterFunc(10, 1, name="outer")
    # pylint: enable=unexpected-keyword-arg

  gdef = g1.as_graph_def()

  # Import GraphDef into new graph, add imported gradients, and test that
  # imported functions can be run
  with ops.Graph().as_default() as g2:
    p1, p2, a, b = importer.import_graph_def(
        gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
    grad = gradients_impl.gradients([a], [p1, p2])

    with self.test_session(graph=g2) as sess:
      feed_dict = {p1: 1, p2: 2}
      a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
      self.assertEqual(a_val, 3.0)
      self.assertEqual(b_val, -1.0)
      # Grad function returns inputs values for testing
      self.assertEqual(grad_val, [1.0, 2.0])
      self.assertEqual(sess.run("external:0"), 11)
      self.assertEqual(sess.run("outer:0"), 21)

  # Export the new graph and reimport to test that imported functions can be
  # successfully exported/imported again
  gdef = g2.as_graph_def()
  with ops.Graph().as_default() as g3:
    p1, p2, a, b = importer.import_graph_def(
        gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
    # Create new gradient functions (in additional to the imported gradient
    # functions created in g2).
    grad = gradients_impl.gradients([a], [p1, p2])
    with self.test_session(graph=g3) as sess:
      feed_dict = {p1: 1, p2: 2}
      a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
      self.assertEqual(a_val, 3.0)
      self.assertEqual(b_val, -1.0)
      self.assertEqual(grad_val, [1.0, 2.0])
      self.assertEqual(sess.run("external:0"), 11)
      self.assertEqual(sess.run("outer:0"), 21)
def TestFunc():
  """Imports `gdef` and returns its single requested tensor, `z:0`.

  `gdef` is a free variable resolved from the enclosing scope.
  """
  (z_tensor,) = importer.import_graph_def(gdef, return_elements=["z:0"])
  return z_tensor
def calibrate(self,
              fetch_names,
              num_runs,
              feed_dict_fn=None,
              input_map_fn=None):
  """Run the calibration and return the calibrated GraphDef.

  Args:
    fetch_names: a list of output tensor name to fetch during calibration.
    num_runs: number of runs of the graph during calibration.
    feed_dict_fn: a function that returns a dictionary mapping input names (as
      strings) in the GraphDef to be calibrated to values (e.g. Python list,
      numpy arrays, etc). One and only one of `feed_dict_fn` and
      `input_map_fn` should be specified.
    input_map_fn: a function that returns a dictionary mapping input names (as
      strings) in the GraphDef to be calibrated to Tensor objects. The values
      of the named input tensors in the GraphDef to be calibrated will be
      re-mapped to the respective `Tensor` values during calibration. It is
      called exactly once, with the calibration graph as the default graph.
      One and only one of `feed_dict_fn` and `input_map_fn` should be
      specified.

  Raises:
    ValueError: if the input combination is invalid, or if `input_map_fn`
      returns a key that is not a `str` or a value that is not a `tf.Tensor`.
    RuntimeError: if this method is called in eager mode. NOTE(review): not
      checked in this method body itself.

  Returns:
    The GraphDef after the calibration.
  """
  assert self._converted
  assert self._need_calibration
  assert not self._calibration_data_collected

  if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                         not input_map_fn):
    raise ValueError(
        "Should specify one and only one of feed_dict_fn and input_map_fn.")

  calibration_graph = ops.Graph()
  input_map = None
  if input_map_fn:
    # Call input_map_fn exactly once, inside the calibration graph, so any
    # tensors it creates live in the graph they are mapped into, and the map
    # that is validated is the very map that is imported.
    with calibration_graph.as_default():
      input_map = input_map_fn()
    for k, v in input_map.items():
      if not isinstance(k, str):
        raise ValueError("Keys of input_map_fn must be of type str")
      if not isinstance(v, tf.Tensor):
        raise ValueError("Values of input_map_fn must be of type tf.Tensor")

  self._calibration_graph = calibration_graph
  with self._calibration_graph.as_default():
    fetches = importer.import_graph_def(
        self._converted_graph_def,
        input_map=input_map,
        return_elements=fetch_names,
        name="")

  with session.Session(
      graph=self._calibration_graph,
      config=self._session_config) as calibration_sess:
    for _ in range(num_runs):
      calibration_sess.run(
          fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)

    # Maps device name to the corresponding get_calibration_data.
    #
    # TODO(laigd): a better way would be to use calibration_sess to list
    # all the devices, add one get_calibration_data for each device, and
    # fetch each such op for every resource until its found. This can work
    # even when the device of the TRTEngineOp is empty or not fully specified.
    device_to_get_resource_op_map = {}

    with self._calibration_graph.as_default():
      resource_name_input = array_ops.placeholder(dtypes.string)

      for node in self._converted_graph_def.node:
        if node.op == _TRT_ENGINE_OP_NAME:
          # Adds the get_calibration_data op for the device if not done
          # before. We only add one such op for each device.
          # TODO(laigd): What if the device is empty?????
          if node.device not in device_to_get_resource_op_map:
            with self._calibration_graph.device(node.device):
              serialized_resources_output = (
                  gen_trt_ops.get_calibration_data_op(resource_name_input))
            device_to_get_resource_op_map[node.device] = (
                serialized_resources_output)

          # Get the calibration resource for this engine and store it on the
          # node of the converted GraphDef (mutated in place).
          calibration_result = calibration_sess.run(
              device_to_get_resource_op_map[node.device],
              feed_dict={
                  resource_name_input: _get_canonical_engine_name(node.name)
              })
          node.attr["calibration_data"].s = calibration_result

    self._calibration_data_collected = True

  return self._converted_graph_def
def _test_variable_to_const_conversion(self, use_resource):
  """Checks `convert_variables_to_constants` whitelist/blacklist handling.

  Builds `output_node = variable_node * 2.0` plus an unused variable, then
  checks that:
  * freezing with a whitelist of `variable_node` matches freezing with no
    whitelist once the unused variable is initialized;
  * a blacklisted `variable_node` stays a variable op (`VarHandleOp` when
    `use_resource`, else `VariableV2`);
  * the frozen graph has 4 nodes, no variables, and still evaluates to 2.0.

  Args:
    use_resource: passed to `variable_scope` to choose resource vs. ref
      variables.
  """
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=use_resource):
      variable_node = variable_scope.get_variable(
          "variable_node", initializer=1.0)
      another_variable = variable_scope.get_variable(
          "unused_variable_node", initializer=1.0)
      output_node = math_ops_lib.multiply(
          variable_node, 2.0, name="output_node")
      with session.Session() as sess:
        self.evaluate(variable_node.initializer)
        output = self.evaluate(output_node)
        self.assertNear(2.0, output, 0.00001)
        variable_graph_def = sess.graph.as_graph_def()
        # First get the constant_graph_def when variable_names_whitelist is
        # set, note that if variable_names_whitelist is not set an error will
        # be thrown because unused_variable_node is not initialized.
        constant_graph_def = graph_util.convert_variables_to_constants(
            sess,
            variable_graph_def, ["output_node"],
            variable_names_whitelist=set(["variable_node"]))

        # Then initialize the unused variable, and get another
        # constant_graph_def when variable_names_whitelist is not set.
        self.evaluate(another_variable.initializer)
        constant_graph_def_without_variable_whitelist = (
            graph_util.convert_variables_to_constants(
                sess, variable_graph_def, ["output_node"]))

        # The unused variable should be cleared so the two graphs should be
        # equivalent.
        self.assertEqual(
            str(constant_graph_def),
            str(constant_graph_def_without_variable_whitelist))

        # Test variable name black list. This should result in the variable
        # not being a const.
        constant_graph_def_with_blacklist = (
            graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def, ["output_node"],
                variable_names_blacklist=set(["variable_node"])))
        # Note: rebinds `variable_node` from the Python variable to the
        # matching NodeDef (or None if absent).
        variable_node = None
        for node in constant_graph_def_with_blacklist.node:
          if node.name == "variable_node":
            variable_node = node
        self.assertIsNotNone(variable_node)
        if use_resource:
          self.assertEqual(variable_node.op, "VarHandleOp")
        else:
          self.assertEqual(variable_node.op, "VariableV2")

  # Now we make sure the variable is now a constant, and that the graph still
  # produces the expected result.
  with ops.Graph().as_default():
    _ = importer.import_graph_def(constant_graph_def, name="")
    self.assertEqual(4, len(constant_graph_def.node))
    self._ensure_no_variables_in_graph(constant_graph_def)
    with session.Session() as sess:
      output_node = sess.graph.get_tensor_by_name("output_node:0")
      output = self.evaluate(output_node)
      self.assertNear(2.0, output, 0.00001)
from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import sys from tensorflow.core.framework import graph_pb2 from tensorflow.python.client import session from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.summary import summary model_dir = 'D:/data/glomeruli/20180202_glomeruli_detection_noquant.pb' log_dir = 'd:/temp/tf' with session.Session(graph=ops.Graph()) as sess: with gfile.FastGFile(model_dir, "rb") as f: graph_def = graph_pb2.GraphDef() graph_def.ParseFromString(f.read()) importer.import_graph_def(graph_def) # pb_visual_writer = summary.FileWriter(log_dir) # pb_visual_writer.add_graph(sess.graph) file_writer = summary.FileWriter(log_dir, sess.graph) print("Model Imported. Visualize by running: tensorboard --logdir={}". format(log_dir))
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements:  A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs that are not exactly covered by `input_map`.
  """
  if context.executing_eagerly():
    raise ValueError(
        "Exporting/importing meta graphs is not supported when "
        "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # An empty CollectionDef has no "kind" set; getattr(col_def, None)
        # would raise TypeError, so treat it as having no unbound inputs.
        if kind is None:
          break
        field = getattr(col_def, kind)
        # Normalize once: values may be bytes while input_map keys are str.
        # The same normalized names drive both the check and the error
        # message, so the message lists only the inputs that are missing.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (not input_map or
                              sorted(unbound_names) != sorted(input_map)):
          raise ValueError(
              "Graph contains unbound inputs: %s. Must "
              "provide these inputs through input_map." % ",".join(
                  name for name in unbound_names
                  if not input_map or name not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          distutils_version.LooseVersion(tf_version) >=
          distutils_version.LooseVersion("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error(
            "Cannot identify data type for collection %s. Skipping.", key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(
            ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # Reuse variables already created from another collection.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES, scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
def _TestTrtGraphConverter(self,
                           device,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """General method to test trt_convert.TrtGraphConverter().

  Converts the test graph via `self._ConvertGraphV1`. When
  `output_saved_model_dir` is given, it also checks the GraphDef stored in
  the SavedModel. For each GraphDef it checks the node-name -> op mapping:
  unconverted on "/CPU:" devices, and a single `TRTEngineOp_000` otherwise.
  When `need_calibration` is set, it checks the calibration data on
  TRTEngineOp nodes and runs the calibrated graph.
  """
  output_graph_def = self._ConvertGraphV1(
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      device=device)
  graph_defs_to_verify = [output_graph_def]

  if output_saved_model_dir:
    saved_model_graph_def = saved_model_utils.get_meta_graph_def(
        output_saved_model_dir, tag_constants.SERVING).graph_def
    self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
    graph_defs_to_verify.append(saved_model_graph_def)

  for graph_def in graph_defs_to_verify:
    # Node names are normalized by _MayRemoveGraphSequenceNumber before
    # comparison.
    node_name_to_op = {
        self._MayRemoveGraphSequenceNumber(node.name): node.op
        for node in graph_def.node
    }
    if device is not None and device.startswith("/CPU:"):
      # On CPU the graph is expected to stay unconverted.
      self.assertEqual(
          {
              "add": "AddV2",
              "v1": "Const",
              "add_1": "AddV2",
              "add_2": "AddV2",
              "input1": "Placeholder",
              "input2": "Placeholder",
              "mul": "Mul",
              "output": "Identity"
          }, node_name_to_op)
    else:
      # Otherwise the arithmetic is expected to be fused into one engine op.
      self.assertEqual(
          {
              "input1": "Placeholder",
              "input2": "Placeholder",
              "TRTEngineOp_000": "TRTEngineOp",
              "output": "Identity"
          }, node_name_to_op)

    if need_calibration:
      trt_engine_nodes = [
          node for node in graph_def.node if node.op == "TRTEngineOp"
      ]
      if device is not None and device.startswith("/CPU:"):
        self.assertEmpty(trt_engine_nodes)
        # NOTE(review): this returns from the whole method, so any later
        # graph_defs in graph_defs_to_verify are not checked on CPU.
        return

      self.assertNotEmpty(trt_engine_nodes)
      for node in trt_engine_nodes:
        self.assertTrue(len(node.attr["calibration_data"].s))
      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            # Expected output for input1 == input2 == test_data.
            self.assertEqual((test_data + 1.0)**2 + test_data,
                             sess.run(
                                 "output:0",
                                 feed_dict={
                                     "input1:0": [[[test_data]]],
                                     "input2:0": [[[test_data]]]
                                 }))
def quantize(saved_model_path: str,
             signature_keys=None,
             tags=None,
             output_directory=None,
             representative_dataset=None):
  """Quantizes the given SavedModel.

  QAT models are quantized in one step. Other models go through
  post-training quantization:
  * a pre-calibration GraphDef is produced and its `CustomAggregator` nodes
    get fresh ids;
  * it is saved as a temporary float SavedModel, and every sample from
    `representative_dataset` is run through it;
  * the collected min/max per aggregator id is written back into the graph;
  * the graph is saved again and passed to the post-calibration pass.
  The resulting GraphDef is saved to `output_directory`.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs
      and outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory). A fresh temporary directory is used when None.
    representative_dataset: a generator that returns a dictionary in
      {input_name: input_tensor} format or a tuple with signature key and a
      dictionary in {input_name: input_tensor} format that feeds calibration
      data for quantizing model. This should be provided when the model is not
      a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for non-QAT model,
      when a calibration sample has the wrong structure, or when no valid
      signature can be mapped onto the imported graph.
  """
  if tags is None:
    tags = set([tag_constants.SERVING])
  if signature_keys is None:
    signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # Checks if the model is from QAT
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # Handle QAT models are supported.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(saved_model_path,
                                                  ','.join(signature_keys),
                                                  ','.join(tags)))
  else:
    # Handle PTQ models are supported with mocking calibration.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    # NOTE(review): the temporary directories created with mkdtemp() here are
    # never removed.
    float_model_dir = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(float_model_dir)

    with session.Session(graph=ops.Graph()) as sess:
      # Give every CustomAggregator node a unique id; the same id is used
      # below to read back the min/max values collected for that node.
      for function_def in graph_def.library.function:
        for node_def in function_def.node_def:
          if node_def.op == 'CustomAggregator':
            node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

      importer.import_graph_def(graph_def, name='')
      working_graph = ops.get_default_graph()
      graph_def = working_graph.as_graph_def()

      signatures = _fix_tensor_names(signatures, working_graph)
      if signatures is None:
        raise ValueError(
            "The input SavedModel doesn't contain a valid signature")

      v1_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING], signature_def_map=signatures)

    v1_builder.save()

    float_model = saved_model_load(float_model_dir)

    # Feed each calibration sample through the float model.
    for sample in representative_dataset():
      # TODO(b/214311251): Add a test case with multiple signatures.
      if isinstance(sample, tuple):
        if not isinstance(sample[1], dict):
          raise ValueError('You need to provide a dictionary with input '
                           'names and values in the second argument in the '
                           'tuple')
        signature_key = sample[0]
        input_data_map = sample[1]
      elif isinstance(sample, dict):
        if len(signature_keys) > 1:
          raise ValueError('When the model has multiple signatures, you need '
                           'to provide a tuple with signature key and a '
                           'dictionary with input names and values')
        signature_key = signature_keys[0]
        input_data_map = sample
      else:
        raise ValueError('You need to provide either a dictionary with input '
                         'names and values or a tuple with signature key and a '
                         'dictionary with input names and values')
      func = float_model.signatures[signature_key]
      func(**input_data_map)

    # Copy the collected min/max values into each aggregator node's attrs,
    # then clear that id's entry in the calibrator.
    for function_def in graph_def.library.function:
      for node_def in function_def.node_def:
        if node_def.op == 'CustomAggregator':
          node_id = node_def.attr['id'].s
          min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
          max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
          quantize_model_wrapper.clear_data_from_calibrator(node_id)
          node_def.attr['min'].f = float(min_val)
          node_def.attr['max'].f = float(max_val)

    calibrated_model_dir = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

    with session.Session(graph=ops.Graph()) as sess:
      importer.import_graph_def(graph_def, name='')
      working_graph = ops.get_default_graph()
      graph_def = working_graph.as_graph_def()

      v1_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING], signature_def_map=signatures)

    v1_builder.save()
    signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                  signature_keys, tags)

    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_post_calibration(
            calibrated_model_dir,
            ','.join(signature_keys),
            ','.join(tags),
        ))

  # Both paths converge here: save the quantized GraphDef as the output model.
  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  if output_directory is None:
    output_directory = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError("The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)
def _testFreezeGraph(self, saver_write_version):
  """Freezes a one-variable graph and checks it still evaluates correctly.

  Builds a graph holding a single variable (1.0) multiplied by 2.0, saves a
  checkpoint with the requested Saver format, runs `freeze_graph` on the
  written GraphDef, and verifies that the frozen graph has 4 nodes, contains
  no `Variable`/`VariableV2` ops and still produces 2.0.

  Args:
    saver_write_version: Saver write version used both to save the
      checkpoint and as `checkpoint_version` for `freeze_graph`.
  """
  checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
  checkpoint_state_name = "checkpoint_state"
  input_graph_name = "input_graph.pb"
  output_graph_name = "output_graph.pb"

  # We'll create an input graph that has a single variable containing 1.0,
  # and that then multiplies it by 2.
  with ops.Graph().as_default():
    variable_node = variables.VariableV1(1.0, name="variable_node")
    output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
    # Use the session as a context manager so it is closed once the
    # checkpoint and the graph have been written (it was leaked before).
    with session.Session() as sess:
      init = variables.global_variables_initializer()
      sess.run(init)
      output = sess.run(output_node)
      self.assertNear(2.0, output, 0.00001)
      saver = saver_lib.Saver(write_version=saver_write_version)
      checkpoint_path = saver.save(
          sess,
          checkpoint_prefix,
          global_step=0,
          latest_filename=checkpoint_state_name)
      graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

  # We save out the graph to disk, and then call the const conversion
  # routine.
  input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
  input_saver_def_path = ""
  input_binary = False
  output_node_names = "output_node"
  restore_op_name = "save/restore_all"
  filename_tensor_name = "save/Const:0"
  output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
  clear_devices = False

  # The three trailing "" are initializer_nodes, the variable whitelist and
  # the variable blacklist (all empty).
  freeze_graph.freeze_graph(
      input_graph_path,
      input_saver_def_path,
      input_binary,
      checkpoint_path,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph_path,
      clear_devices,
      "",
      "",
      "",
      checkpoint_version=saver_write_version)

  # Now we make sure the variable is now a constant, and that the graph still
  # produces the expected result.
  with ops.Graph().as_default():
    output_graph_def = graph_pb2.GraphDef()
    with open(output_graph_path, "rb") as f:
      output_graph_def.ParseFromString(f.read())
      _ = importer.import_graph_def(output_graph_def, name="")

    self.assertEqual(4, len(output_graph_def.node))
    for node in output_graph_def.node:
      self.assertNotEqual("VariableV2", node.op)
      self.assertNotEqual("Variable", node.op)

    with session.Session() as sess:
      output_node = sess.graph.get_tensor_by_name("output_node:0")
      output = sess.run(output_node)
      self.assertNear(2.0, output, 0.00001)
def function_def_to_graph(fdef, input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs`, `outputs` and `control_outputs`
  fields are set, and the input tensors are represented as placeholders.
  `FuncGraph.captures` is not set and may be set by the caller.

  Shapes recorded in a node's "_output_shapes" attribute are applied to the
  corresponding outputs of the imported op. If that attribute lists more
  shapes than the op has outputs, the surplus entries are ignored instead of
  raising an IndexError.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute.
      If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name)
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      input_shapes = input_shapes_attr.list.shape
  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs: signature argument names are mapped to the flattened tensor
    # names used inside `graph_def`.
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name]
        for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs: `fdef.ret` maps each output arg to its (nested) tensor name.
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # Only apply shapes that correspond to real outputs of the op; a
        # longer "_output_shapes" list would otherwise index past op.outputs.
        num_outputs = len(op.outputs)
        for output_index, shape in enumerate(
            output_shapes.list.shape[:num_outputs]):
          op.outputs[output_index].set_shape(shape)
  return func_graph
import os

import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

# Expose only GPU 0 to TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Binary-serialized GraphDef whose float operations are counted.
pb_path = 'mymodel.pb'
run_meta = tf.RunMetadata()

with tf.Graph().as_default() as graph:
  output_graph_def = graph_pb2.GraphDef()
  with open(pb_path, "rb") as f:
    output_graph_def.ParseFromString(f.read())
  _ = importer.import_graph_def(output_graph_def, name="")
  print('model loaded!')

  # Profile the graph the model was imported into, passed explicitly rather
  # than via whatever graph happens to be the default. tf.profiler.profile
  # takes the graph directly; the Session previously opened here was never
  # used, and the sorted node-name list previously built here only fed
  # commented-out code, so both were dropped.
  flops = tf.profiler.profile(
      graph,
      run_meta=run_meta,
      options=tf.profiler.ProfileOptionBuilder.float_operation())
  print("test flops:{:,}".format(flops.total_float_ops))
def from_frozen_graph(cls,
                      graph_def_file,
                      input_arrays,
                      output_arrays,
                      input_shapes=None):
  """Creates a TocoConverter class from a file containing a frozen GraphDef.

  The GraphDef is parsed as a binary proto first and, failing that, as a
  text proto. It is imported into a freshly created Graph, so the
  process-wide default graph is left untouched and repeated calls cannot
  clash on node names.

  Args:
    graph_def_file: Full filepath of file containing TensorFlow GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" :
      None}). (default None)

  Returns:
    TocoConverter class.

  Raises:
    ValueError:
      Unable to parse input file.
      The graph is not frozen.
      input_arrays or output_arrays contains an invalid tensor name.
  """
  # Local import: a private Graph isolates the imported nodes from the
  # default graph (previously `sess.graph.as_default()` was called without
  # `with`, which has no effect).
  from tensorflow.python.framework import ops as _ops  # pylint: disable=g-import-not-at-top

  with _ops.Graph().as_default():
    with _session.Session() as sess:
      # Read GraphDef from file.
      graph_def = _graph_pb2.GraphDef()
      with open(graph_def_file, "rb") as f:
        file_content = f.read()
      try:
        graph_def.ParseFromString(file_content)
      except (_text_format.ParseError, DecodeError):
        # Not a binary GraphDef; retry as a text-format proto.
        try:
          print("Ignore 'tcmalloc: large alloc' warnings.")
          if not isinstance(file_content, str):
            if PY3:
              file_content = file_content.decode('utf-8')
            else:
              file_content = file_content.encode('utf-8')
          _text_format.Merge(file_content, graph_def)
        except (_text_format.ParseError, DecodeError):
          raise ValueError("Unable to parse input file '{}'.".format(
              graph_def_file))
      import_graph_def(graph_def, name="")

      # Get input and output tensors.
      input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
      output_tensors = get_tensors_from_tensor_names(sess.graph,
                                                     output_arrays)
      set_tensor_shapes(input_tensors, input_shapes)

      # Check if graph is frozen.
      if not _is_frozen_graph(sess):
        raise ValueError("Please freeze the graph using freeze_graph.py.")

      # Create TocoConverter class.
      return cls(sess.graph_def, input_tensors, output_tensors)
def testInvalidInputForReturnOperations(self):
  """import_graph_def rejects a non-string entry in `return_elements`."""
  expected_message = "return_elements must be a list of strings."
  with ops.Graph().as_default():
    with self.assertRaises(TypeError) as raised:
      importer.import_graph_def(
          self._MakeGraphDef(""), return_elements=[7])
    # The exception is only available once the assertRaises block has exited.
    self.assertEqual(expected_message, str(raised.exception))
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Variable values are restored from the first available source, checked in
  this order: `input_saver_def`, `input_meta_graph_def`,
  `input_saved_model_dir`, and finally a Saver built from the checkpoint's
  variables that also exist in the graph.

  Args:
    input_graph_def: GraphDef to freeze. When given, it is imported into the
      current default graph. It is not the conversion source when
      `input_meta_graph_def` is set.
    input_saver_def: Optional SaverDef used to restore `input_checkpoint`.
    input_checkpoint: Checkpoint path, or a prefix for Saver V2 checkpoints.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Unused; kept for backwards compatibility.
    filename_tensor_name: Unused; kept for backwards compatibility.
    output_graph: If non-empty, path the serialized frozen GraphDef is
      written to.
    clear_devices: If true, clear device placements on the input graph nodes.
    initializer_nodes: Comma-separated node names run after restoring. Only
      used in the meta-graph and checkpoint-reader restore paths.
    variable_names_whitelist: Comma-separated variable names to convert. An
      empty string passes None on.
    variable_names_blacklist: Comma-separated variable names not to convert.
      An empty string passes None on.
    input_meta_graph_def: Optional MetaGraphDef to import and restore from.
    input_saved_model_dir: Optional SavedModel directory to load from.
    saved_model_tags: Tags used to load `input_saved_model_dir`; an empty
      list when None.
    checkpoint_version: Saver write version used for any Saver built here.

  Returns:
    The frozen GraphDef. Returns -1 (after printing a message) if the
    checkpoint does not exist or no output node names were given.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  def _split_names(names):
    """Splits a comma-separated string into names, removing all spaces."""
    return names.replace(" ", "").split(",")

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(
          var_list=var_list, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))

    variable_names_whitelist = (
        _split_names(variable_names_whitelist)
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        _split_names(variable_names_blacklist)
        if variable_names_blacklist else None)

    # A MetaGraphDef carries its own GraphDef; otherwise convert the GraphDef
    # that was passed in directly.
    source_graph_def = (
        input_meta_graph_def.graph_def
        if input_meta_graph_def else input_graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        source_graph_def,
        _split_names(output_node_names),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
def main(unused_argv=None):
  """Runs live webcam style transfer until Esc is pressed.

  Builds the model with `build_model.build_model` and restores it from
  FLAGS.checkpoint. When FLAGS.tensorrt is set, it also builds a TensorRT
  graph. It then repeatedly grabs a webcam frame, crops it to the 200x1200
  placeholder size, interpolates between identity and style bottleneck
  features, runs the stylization and shows the result.

  Keyboard controls (cv2.waitKey codes):
    27 (Esc) quits.
    192 leaves full screen.
    233 or 193 enters full screen.
    60 ('<') lowers the interpolation weight by 0.25.
    62 ('>') raises the interpolation weight by 0.25.
  """
  # Local import: only needed to back off while no Style is active yet.
  import time  # pylint: disable=g-import-not-at-top

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.tensorrt:
    print(trt.trt_convert.get_linked_tensorrt_version())
    gpu_options = cpb2.GPUOptions(
        per_process_gpu_memory_fraction=_GPU_MEM_FRACTION)
    sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
  else:
    sessconfig = None

  # Instantiate video capture object.
  cap = cv2.VideoCapture(1)
  try:
    # OpenCV property IDs: 3 is frame width, 4 is frame height. Request
    # 1280x1024 because the crop below (columns 80:1280, rows 500:700) needs
    # a frame at least 1280 wide and 700 tall. The previous request asked for
    # a width of only 1024.
    prop_width, prop_height = 3, 4
    frame_width, frame_height = (1280, 1024)
    cap.set(prop_width, frame_width)
    cap.set(prop_height, frame_height)
    cap.read()
    x_new = int(cap.get(prop_width))
    y_new = int(cap.get(prop_height))
    print('Resolution is: {0} by {1}'.format(x_new, y_new))

    with tf.Graph().as_default(), tf.Session(config=sessconfig) as sess:
      # TODO: calculate these dimensions dynamically (they can't be None
      # since TensorRT needs precalculated dimensions).
      # Defines place holder for the style image.
      style_img_ph = tf.placeholder(
          tf.float32, shape=[200, 1200, 3], name="style_img_ph")
      if FLAGS.style_square_crop:
        style_img_preprocessed = image_utils.center_crop_resize_image(
            style_img_ph, FLAGS.style_image_size)
      else:
        style_img_preprocessed = image_utils.resize_image(
            style_img_ph, FLAGS.style_image_size)

      # Defines place holder for the content image.
      content_img_ph = tf.placeholder(
          tf.float32, shape=[200, 1200, 3], name="content_img_ph")
      if FLAGS.content_square_crop:
        content_img_preprocessed = image_utils.center_crop_resize_image(
            content_img_ph, FLAGS.image_size)
      else:
        content_img_preprocessed = image_utils.resize_image(
            content_img_ph, FLAGS.image_size)

      # Defines the model.
      stylized_images, _, _, bottleneck_feat = build_model.build_model(
          content_img_preprocessed,
          style_img_preprocessed,
          trainable=False,
          is_training=False,
          inception_end_point='Mixed_6e',
          style_prediction_bottleneck=100,
          adds_losses=False)
      print(stylized_images)
      print(bottleneck_feat)

      if tf.gfile.IsDirectory(FLAGS.checkpoint):
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
      else:
        checkpoint = FLAGS.checkpoint
      tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))
      init_fn = slim.assign_from_checkpoint_fn(
          checkpoint, slim.get_variables_to_restore())
      sess.run([tf.local_variables_initializer()])
      init_fn(sess)
      tf.train.write_graph(sess.graph_def, '.', 'model.pbtxt')

      if FLAGS.tensorrt:
        # We use a built-in TF helper to export variables to constants.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights.
            # The graph_def is used to retrieve the nodes.
            tf.get_default_graph().as_graph_def(),
            # The output node names are used to select the useful nodes.
            ['transformer/expand/conv3/conv/Sigmoid'])
        trt_graph = trt.create_inference_graph(
            input_graph_def=output_graph_def,
            outputs=["transformer/expand/conv3/conv/Sigmoid"],
            max_workspace_size_bytes=5 << 30,
            max_batch_size=1,
            # TRT Engine precision "FP32", "FP16" or "INT8".
            precision_mode="FP16",
            minimum_segment_size=10)
        bottleneck_feat_O, content_img_ph_O, stylized_images_O = (
            importer.import_graph_def(
                graph_def=trt_graph,
                return_elements=[
                    "Conv/BiasAdd", "content_img_ph",
                    "transformer/expand/conv3/conv/Sigmoid"
                ]))
        bottleneck_feat_O = bottleneck_feat_O.outputs[0]
        content_img_ph_O = content_img_ph_O.outputs[0]
        stylized_images_O = stylized_images_O.outputs[0]
        print("bottleneck opt:" + str(bottleneck_feat_O))
        print(content_img_ph_O)
        print(stylized_images_O)

      interpolation_weight = FLAGS.interpolation_weight
      activate_style = None
      style_params = None
      while True:
        start = timer()
        # Calculating style isn't the major FPS bottleneck.
        current_style = Style.objects.filter(is_active=True).first()
        # Recompute style features only when a different Style is active. If
        # no Style is active, keep using the last one instead of
        # dereferencing None.
        if current_style is not None and activate_style != current_style:
          activate_style = current_style
          style_img_path = activate_style.source_file.path
          print("current image is " + style_img_path)
          style_image_np = image_utils.load_np_image_uint8(
              style_img_path)[:, :, :3]
          style_image_np = cv2.resize(style_image_np, (1200, 200))
          # Computes bottleneck features of the style prediction network for
          # the given style image.
          style_params = sess.run(
              bottleneck_feat, feed_dict={style_img_ph: style_image_np})

        if style_params is None:
          # No Style has been active yet, so there is nothing to stylize
          # with. Wait briefly instead of crashing.
          print("No active style; waiting.")
          time.sleep(1.0)
          continue

        ret, frame = cap.read()
        if not ret or frame is None:
          tf.logging.error('Failed to read a frame from the webcam.')
          break
        print("webcam image: " + str(frame.shape))
        # Crop to the 200x1200 region the placeholders expect.
        content_img_np = frame[500:700, 80:1280]
        print("cropped image:" + str(content_img_np.shape))
        print("Input image:" + str(content_img_np.shape))

        # Computes bottleneck features of the style prediction network for
        # the identity transform (the content frame used as the style).
        identity_params = sess.run(
            bottleneck_feat, feed_dict={style_img_ph: content_img_np})

        # Interpolates between the parameters of the identity transform and
        # style parameters of the given style image.
        wi = interpolation_weight
        style_np = identity_params * (1 - wi) + style_params * wi

        if FLAGS.tensorrt:
          style_np = np.reshape(style_np, (1, 100, 1, 1))
          stylized_image_res = sess.run(
              stylized_images_O,
              feed_dict={
                  bottleneck_feat_O: style_np,
                  content_img_ph_O: content_img_np
              })
        else:
          stylized_image_res = sess.run(
              stylized_images,
              feed_dict={
                  bottleneck_feat: style_np,
                  content_img_ph: content_img_np
              })
        end = timer()
        print(end - start)

        display_np_image(stylized_image_res, FLAGS.showFullScreen)
        print(stylized_image_res.shape)

        key = cv2.waitKey(10)
        print("Key " + str(key))
        if key == 27:  # Esc
          break
        elif key == 192:
          FLAGS.showFullScreen = False
          cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                cv2.WINDOW_NORMAL)
        elif key in (233, 193):
          FLAGS.showFullScreen = True
          cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                cv2.WINDOW_FULLSCREEN)
        elif key == 60:  # '<': less style
          interpolation_weight -= 0.25
        elif key == 62:  # '>': more style
          interpolation_weight += 0.25
  finally:
    # Always release the camera, even if model setup or the loop fails.
    cap.release()
    cv2.destroyAllWindows()
def testStripUnusedMultipleInputs(self):
  """Strips a graph down to `output_node` with two inputs replaced.

  Builds (1 - 3) * (2 - 5) plus an unrelated `later_node` and writes it out.
  It then runs strip_unused with input_node1/input_node2 as inputs and checks
  that the stripped graph:
    * has exactly three nodes;
    * contains no Add/Sub ops;
    * multiplies the values fed to the two inputs.
  """
  input_graph_name = "input_graph.pb"
  output_graph_name = "output_graph.pb"

  # We'll create an input graph that multiplies two input nodes.
  with ops.Graph().as_default():
    constant_node1 = constant_op.constant(1.0, name="constant_node1")
    constant_node2 = constant_op.constant(2.0, name="constant_node2")
    input_node1 = math_ops.sub(constant_node1, 3.0, name="input_node1")
    input_node2 = math_ops.sub(constant_node2, 5.0, name="input_node2")
    output_node = math_ops.multiply(
        input_node1, input_node2, name="output_node")
    math_ops.add(output_node, 2.0, name="later_node")
    # Use the session as a context manager so it is closed once the graph
    # is written (it was leaked before).
    with session.Session() as sess:
      output = sess.run(output_node)
      self.assertNear(6.0, output, 0.00001)
      graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

  # We save out the graph to disk, and then call the const conversion
  # routine.
  input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
  input_binary = False
  input_node_names = "input_node1,input_node2"
  input_node_types = [
      dtypes.float32.as_datatype_enum, dtypes.float32.as_datatype_enum
  ]
  output_binary = True
  output_node_names = "output_node"
  output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

  strip_unused_lib.strip_unused_from_files(
      input_graph_path, input_binary, output_graph_path, output_binary,
      input_node_names, output_node_names, input_node_types)

  # Now we make sure the variable is now a constant, and that the graph still
  # produces the expected result.
  with ops.Graph().as_default():
    output_graph_def = graph_pb2.GraphDef()
    with open(output_graph_path, "rb") as f:
      output_graph_def.ParseFromString(f.read())
      _ = importer.import_graph_def(output_graph_def, name="")

    self.assertEqual(3, len(output_graph_def.node))
    for node in output_graph_def.node:
      self.assertNotEqual("Add", node.op)
      self.assertNotEqual("Sub", node.op)
      # NOTE(review): `input_node_names` is the comma-joined string
      # "input_node1,input_node2", so this equality never holds and the
      # shape assertion never runs. Kept as-is; confirm whether strip_unused
      # is expected to emit a "shape" attr before tightening it.
      if node.name == input_node_names:
        self.assertTrue("shape" in node.attr)

    with session.Session() as sess:
      input_node1 = sess.graph.get_tensor_by_name("input_node1:0")
      input_node2 = sess.graph.get_tensor_by_name("input_node2:0")
      output_node = sess.graph.get_tensor_by_name("output_node:0")
      output = sess.run(
          output_node,
          feed_dict={
              input_node1: [10.0],
              input_node2: [-5.0]
          })
      self.assertNear(-50.0, output, 0.00001)
def testEmptyGraph(self):
  """Importing an empty GraphDef must not bump the graph's version."""
  with ops.Graph().as_default() as graph:
    version_before = graph.version
    empty_graph_def = self._MakeGraphDef("")
    importer.import_graph_def(empty_graph_def)
    self.assertEqual(version_before, graph.version)
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode=TrtPrecisionMode.FP32,
                           minimum_segment_size=3,
                           is_dynamic_op=False,
                           maximum_cached_engines=1,
                           cached_engine_batches=None,
                           use_calibration=True,
                           input_saved_model_dir=None,
                           input_saved_model_tags=None,
                           output_saved_model_dir=None,
                           session_config=None):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: a GraphDef object containing a model to be transformed.
      If set to None, the graph will be read from the SavedModel loaded from
      input_saved_model_dir.
    outputs: list of tensors or node names for the model outputs. Only used
      when input_graph_def is not None.
    max_batch_size: max size for the input batch.
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the
      'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the
      TRT network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT
      ops. If the number of cached engines is already at max but none of them
      can serve the input, the TRTEngineOp will fall back to run the TF
      function based on which the TRTEngineOp is created.
    cached_engine_batches: a list of batch sizes used to create cached
      engines, only used when is_dynamic_op is True. The length of the list
      should be <= maximum_cached_engines, and the dynamic TRT op will use
      this list to determine the batch sizes of the cached engines, instead
      of making the decision on the fly. This is useful when we know the most
      common batch size(s) the application is going to generate.
    use_calibration: this argument is ignored if precision_mode is not INT8.
      If set to True, a calibration graph will be created to calibrate the
      missing ranges. The calibration graph must be converted to an inference
      graph using calib_graph_to_infer_graph() after running calibration. If
      set to False, quantization nodes will be expected for every tensor in
      the graph (excluding those which will be fused). If a range is missing,
      an error will occur. Please note that accuracy may be negatively
      affected if there is a mismatch between which tensors TRT quantizes and
      which tensors were trained with fake quantization.
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transform. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel. Defaults to
      [tag_constants.SERVING] when None.
    output_saved_model_dir: if not None, construct a SavedModel using the
      returned GraphDef and save it to the specified directory. This option
      only works when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None.
    session_config: the ConfigProto used to create a Session. It's also used
      as a template to create a TRT-enabled ConfigProto for conversion. If not
      specified, a default ConfigProto will be used.

  Returns:
    A GraphDef transformed from input_graph_def (or the SavedModel graph def
    loaded from input_saved_model_dir, if input_graph_def is not present).
    All TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
    function is added for each of the subgraphs.

    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
    subgraph GraphDef. It is converted to a TRT engine at execution time,
    and the engine is cached for future usage. A new TRT engine will be
    created each time none of the cached engines match the input shapes. If
    it fails to execute the TRT engine, or the number of cached engines
    reaches maximum_cached_engines, the op will fall back to call the
    corresponding TF function.

    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
    engine created from the corresponding subgraph. No more engines will be
    created on the fly, and the op will fall back to call the corresponding
    TF function when it fails to execute the engine.

  Raises:
    ValueError: if the combination of the parameters is invalid.
    RuntimeError: if the TensorRT library version is incompatible.
  """
  compiled_version = get_linked_tensorrt_version()
  loaded_version = get_loaded_tensorrt_version()
  # Format both versions once; they are used in several log messages.
  compiled_version_str = ".".join(str(x) for x in compiled_version)
  loaded_version_str = ".".join(str(x) for x in loaded_version)

  # An older major version at runtime than at build time is fatal.
  if loaded_version[0] < compiled_version[0]:
    tf_logging.error(
        "TensorRT version mismatch. Tensorflow was compiled against " +
        "TensorRT %s but library loaded from environment is TensorRT %s" %
        (compiled_version_str, loaded_version_str) +
        ". Please make sure that correct version of TensorRT " +
        "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
    raise RuntimeError("Incompatible TensorRT library version")

  # Any other component-wise difference only produces a warning.
  version_mismatch = any(
      loaded != compiled
      for loaded, compiled in zip(loaded_version, compiled_version))
  if version_mismatch:
    tf_logging.warning("TensorRT mismatch. Compiled against version " +
                       "%s, but loaded %s. Things may not work" %
                       (compiled_version_str, loaded_version_str))
  else:
    tf_logging.info("Running against TensorRT version %s" % loaded_version_str)

  if session_config is None:
    session_config = config_pb2.ConfigProto()

  if input_saved_model_tags is None:
    input_saved_model_tags = [tag_constants.SERVING]

  if input_graph_def is None:
    # Read from SavedModel and freeze the graph if necessary.
    if input_saved_model_dir is None:
      raise ValueError("input_graph_def and input_saved_model_dir cannot be "
                       "both None")
    with ops.Graph().as_default():
      with session.Session(config=session_config) as sess:
        saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
        input_meta_graph_def = saved_model_loader.load(sess,
                                                       input_saved_model_tags)
        output_node_names = set()

        def _gather_names(tensor_info):
          """Get the node names from a TensorInfo."""
          return {tensor_info[key].name.split(":")[0] for key in tensor_info}

        # Get input and outputs from all SignatureDef.
        for key in input_meta_graph_def.signature_def:
          signature_def = input_meta_graph_def.signature_def[key]
          output_node_names.update(_gather_names(signature_def.inputs))
          output_node_names.update(_gather_names(signature_def.outputs))

        # Freeze the variables in the SavedModel graph and copy the frozen
        # graph over.
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(add_shapes=True),
            list(output_node_names))
        grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
        grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

        # Copy the collections that are not variables.
        excluded_collections = ("variables", "local_variables",
                                "model_variables", "trainable_variables",
                                "train_op", "table_initializer")
        for key in input_meta_graph_def.collection_def:
          # TODO(laigd): currently we use the collection key to filter out
          # collections that depend on variable ops, but this may miss some
          # other user-defined collections. A better way would be to use
          # CollectionDef::NodeList for the filtering.
          if key not in excluded_collections:
            grappler_meta_graph_def.collection_def[key].CopyFrom(
                input_meta_graph_def.collection_def[key])

        # Copy other information.
        grappler_meta_graph_def.meta_info_def.CopyFrom(
            input_meta_graph_def.meta_info_def)
        for key in input_meta_graph_def.signature_def:
          grappler_meta_graph_def.signature_def[key].CopyFrom(
              input_meta_graph_def.signature_def[key])
        # TODO(laigd): maybe add back AssetFileDef.
  else:
    if output_saved_model_dir is not None:
      raise ValueError("output_saved_model_dir cannot be set when "
                       "input_graph_def is set")
    # Create MetaGraphDef from input graph.
    graph = ops.Graph()
    with graph.as_default():
      importer.import_graph_def(input_graph_def, name="")
    grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
    if outputs:
      output_collection = meta_graph_pb2.CollectionDef()
      output_list = output_collection.node_list.value
      for output in outputs:
        if isinstance(output, ops.Tensor):
          output_list.append(_to_bytes(output.name))
        else:
          output_list.append(_to_bytes(output))
      # TODO(laigd): use another key as the outputs are really not train_op.
      grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
          output_collection)

  # Create TRT-enabled ConfigProto.
  session_config_with_trt = config_pb2.ConfigProto()
  session_config_with_trt.CopyFrom(session_config)
  rewriter_config = None
  if (session_config_with_trt.HasField("graph_options") and
      session_config_with_trt.graph_options.HasField("rewrite_options")):
    rewriter_config = session_config_with_trt.graph_options.rewrite_options
  rewriter_config_with_trt = get_tensorrt_rewriter_config(
      rewriter_config, max_batch_size, max_workspace_size_bytes,
      precision_mode, minimum_segment_size, is_dynamic_op,
      maximum_cached_engines, cached_engine_batches, use_calibration)
  session_config_with_trt.graph_options.rewrite_options.CopyFrom(
      rewriter_config_with_trt)

  # Run Grappler.
  transformed_graph_def = tf_optimizer.OptimizeGraph(
      session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph")

  # Optionally write the transformed graphdef as SavedModel.
  if output_saved_model_dir is not None:
    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
    with ops.Graph().as_default():
      importer.import_graph_def(transformed_graph_def, name="")
      # We don't use TRT here.
      with session.Session(config=session_config) as sess:
        saved_model_builder.add_meta_graph_and_variables(
            sess,
            input_saved_model_tags,
            signature_def_map=grappler_meta_graph_def.signature_def)
    # Ignore other meta graphs from the input SavedModel.
    saved_model_builder.save()

  return transformed_graph_def
def testMultipleImport(self):
  """Checks how repeated imports into one graph de-duplicate node names.

  Every import below targets the same default graph, so the expected names
  depend on the order of the calls: names created by earlier imports (or by
  ops/name scopes created in between) force later imports to pick new ones.
  """
  source_def = self._MakeGraphDef(""" node { name: 'A' op: 'IntOutput' } node { name: 'B' op: 'IntInput' input: 'A:0' } """)

  with ops.Graph().as_default():
    # Import the same GraphDef three times with an empty prefix. The first
    # import keeps the original names; the next two are suffixed "_1", "_2".
    # The B -> A edge must be remapped to the freshly imported A each time.
    for suffix in ("", "_1", "_2"):
      op_a, op_b = importer.import_graph_def(
          source_def, return_elements=["A", "B"], name="")
      self.assertEqual(op_a.name, "A" + suffix)
      self.assertEqual(op_b.name, "B" + suffix)
      self.assertEqual(list(op_b.inputs), [op_a.outputs[0]])

    # "A", "A_1" and "A_2" are already in use, so the import prefix "A"
    # itself is uniquified and the nodes land under "A_3/".
    op_a, op_b = importer.import_graph_def(
        source_def, return_elements=["A", "B"], name="A")
    self.assertEqual(op_a.name, "A_3/A")
    self.assertEqual(op_b.name, "A_3/B")
    self.assertEqual(list(op_b.inputs), [op_a.outputs[0]])

    # Node names that already look de-duplicated ("A_1", "B_1") collide with
    # existing ops and receive a further "_1" suffix.
    dedup_a, dedup_b = importer.import_graph_def(
        self._MakeGraphDef(""" node { name: 'A_1' op: 'IntOutput' } node { name: 'B_1' op: 'IntInput' input: 'A_1:0' } """),
        return_elements=["A_1", "B_1"],
        name="")
    self.assertEqual(dedup_a.name, "A_1_1")
    self.assertEqual(dedup_b.name, "B_1_1")
    self.assertEqual(list(dedup_b.inputs), [dedup_a.outputs[0]])

    # Opening a name scope "foo" reserves that name, so an imported node
    # called "foo" is renamed to "foo_1".
    with ops.name_scope("foo"):
      constant_op.constant(1)
    (foo,) = importer.import_graph_def(
        self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"),
        return_elements=["foo"],
        name="")
    self.assertEqual(foo.name, "foo_1")

    # Build the nested scope "outer/inner" containing op "outer/inner/c".
    with ops.name_scope("outer"):
      with ops.name_scope("inner"):
        nested_c = constant_op.constant(1, name="c")
        self.assertEqual(nested_c.op.name, "outer/inner/c")

    # A node named after the outer scope ("outer"), the full scope path
    # ("outer/inner") or the existing op ("outer/inner/c") is renamed. A bare
    # name matching only the intermediate scope component ("inner") is not.
    outer, inner, new_c, outer_inner, outer_inner_c = (
        importer.import_graph_def(
            self._MakeGraphDef(
                "node { name: 'outer' op: 'IntOutput' }"
                "node { name: 'inner' op: 'IntOutput' }"
                "node { name: 'c' op: 'IntOutput' }"
                "node { name: 'outer/inner' op: 'IntOutput' }"
                "node { name: 'outer/inner/c' op: 'IntOutput' }"),
            return_elements=["outer", "inner", "c", "outer/inner",
                             "outer/inner/c"],
            name=""))
    expectations = (
        (outer, "outer_1"),
        (inner, "inner"),
        (new_c, "c"),
        (outer_inner, "outer/inner_1"),
        (outer_inner_c, "outer/inner/c_1"),
    )
    for imported_op, expected_name in expectations:
      self.assertEqual(imported_op.name, expected_name)