def testBasicMemory(self):
  """Check the per-tensor and peak memory usage in the memory report.

  Builds a four-node CPU graph (two constants, two `add_n` ops), fetches the
  last node through the TRAIN_OP collection, and verifies the sizes reported
  by `cost_analyzer.GenerateMemoryReport`.
  """
  with test_util.device(use_gpu=False):
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    # Register `d` in the TRAIN_OP collection so the exported metagraph has a
    # fetch target for the analyzer (presumably required by it -- confirm).
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

  report = cost_analyzer.GenerateMemoryReport(mg)

  # Print the report to make it easier to debug.
  print(report)

  # Expected values: each of the four scalar tensors is reported at 4 bytes,
  # and the CPU peak is 16 bytes.
  self.assertIn(
      "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
      "16 bytes", report)
  for tensor_name in ("a", "b", "c", "d"):
    self.assertIn(" %s:0 uses 4 bytes" % tensor_name, report)
def main(_):
  """Optimize the input metagraph and print its cost report.

  The graph returned by `get_metagraph()` is rewritten with
  `tf_optimizer.OptimizeGraph`, using a RewriterConfig optionally parsed from
  the text-format --rewriter_config flag. The cost report is always printed;
  the memory report is printed only when --memory_report is set.
  """
  meta = get_metagraph()

  config = rewriter_config_pb2.RewriterConfig()
  if FLAGS.rewriter_config is not None:
    text_format.Merge(FLAGS.rewriter_config, config)

  # Swap the metagraph's graph for the optimized one in place, so both
  # reports below describe the rewritten graph.
  meta.graph_def.CopyFrom(tf_optimizer.OptimizeGraph(config, meta))

  print(cost_analyzer.GenerateCostReport(meta, FLAGS.per_node_report,
                                         FLAGS.verbose))
  if FLAGS.memory_report:
    print(cost_analyzer.GenerateMemoryReport(meta))
def _parse_proto_file(path, message):
  """Fill `message` from the file at `path` and return it.

  Files whose name ends in ".pbtxt" are parsed as text-format protos; any other
  file is parsed as a binary-serialized proto.
  """
  if path.endswith(".pbtxt"):
    with gfile.GFile(path) as proto_file:
      text_format.Merge(proto_file.read(), message)
  else:
    # Serialized protos are raw bytes: open in binary mode so the contents are
    # not decoded as text (ParseFromString needs a bytes-like object).
    with gfile.GFile(path, "rb") as proto_file:
      message.ParseFromString(proto_file.read())
  return message


def _load_metagraph():
  """Build the MetaGraphDef to analyze from --metagraphdef or --graphdef.

  Raises:
    ValueError: if --graphdef is used without --fetch.
  """
  if FLAGS.metagraphdef:
    metagraph = _parse_proto_file(FLAGS.metagraphdef,
                                  meta_graph_pb2.MetaGraphDef())
    if FLAGS.fetch is not None:
      # Replace the metagraph's "train_op" collection with the single --fetch
      # node.
      fetch_collection = meta_graph_pb2.CollectionDef()
      fetch_collection.node_list.value.append(FLAGS.fetch)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
    return metagraph

  # A bare GraphDef carries no collections, so the fetch op must be named.
  if FLAGS.fetch is None:
    raise ValueError("--fetch is required when analyzing a --graphdef file.")
  graph_def = _parse_proto_file(FLAGS.graphdef, graph_pb2.GraphDef())
  importer.import_graph_def(graph_def, name="")
  graph = ops.get_default_graph()
  fetch = graph.get_operation_by_name(FLAGS.fetch)
  # Put the fetch op in "train_op" so it is exported with the metagraph.
  graph.add_to_collection("train_op", fetch)
  return saver.export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)


def main(_):
  """Optimize the loaded graph, then print its cost and memory reports.

  The graph is rewritten with `tf_optimizer.OptimizeGraph`, using a
  RewriterConfig optionally parsed from the text-format --rewriter_config flag.
  """
  metagraph = _load_metagraph()

  rewriter_config = rewriter_config_pb2.RewriterConfig()
  if FLAGS.rewriter_config is not None:
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
  optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
  # Report on the optimized graph rather than the original one.
  metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
  report = cost_analyzer.GenerateMemoryReport(metagraph)
  print(report)