コード例 #1
0
    def testBasicMemory(self):
        """Check GenerateMemoryReport's output for a small CPU-only graph.

        Builds two constants (a, b) and two add_n ops (c = add_n([a, b]),
        d = add_n([b, c])) on CPU. It then verifies the report's CPU peak-usage
        line and a per-tensor usage line for each of a:0, b:0, c:0 and d:0.
        """
        with test_util.device(use_gpu=False):
            a = constant_op.constant(10, name="a")
            b = constant_op.constant(20, name="b")
            c = math_ops.add_n([a, b], name="c")
            d = math_ops.add_n([b, c], name="d")
            # Register d in the TRAIN_OP collection so it is part of the
            # exported metagraph's fetch set.
            train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            train_op.append(d)
            mg = meta_graph.create_meta_graph_def(
                graph=ops.get_default_graph())

        report = cost_analyzer.GenerateMemoryReport(mg)

        # Print the report to make it easier to debug.
        print(report)

        # assertIn (rather than assertTrue(x in report)) includes the report
        # text in the failure message.
        self.assertIn(
            "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
            "16 bytes", report)
        for tensor_name in ("a:0", "b:0", "c:0", "d:0"):
            self.assertIn("  %s uses 4 bytes" % tensor_name, report)
コード例 #2
0
def main(_):
    """Optimize the input metagraph and print the analyzer report(s).

    Loads the metagraph via get_metagraph() and applies the optional
    RewriterConfig given in FLAGS.rewriter_config (text format). It then runs
    tf_optimizer.OptimizeGraph and replaces the metagraph's graph_def with the
    result. The cost report is always printed; the memory report is printed
    only when FLAGS.memory_report is truthy.
    """
    metagraph = get_metagraph()

    config = rewriter_config_pb2.RewriterConfig()
    user_config_text = FLAGS.rewriter_config
    if user_config_text is not None:
        text_format.Merge(user_config_text, config)

    metagraph.graph_def.CopyFrom(tf_optimizer.OptimizeGraph(config, metagraph))

    print(cost_analyzer.GenerateCostReport(
        metagraph, FLAGS.per_node_report, FLAGS.verbose))

    if FLAGS.memory_report:
        print(cost_analyzer.GenerateMemoryReport(metagraph))
コード例 #3
0
def _parse_proto_file(path, message):
    """Parse the protobuf file at `path` into `message` and return `message`.

    Paths ending in ".pbtxt" are parsed as protobuf text format; any other
    path is parsed as the binary wire format.
    """
    if path.endswith(".pbtxt"):
        with gfile.GFile(path, "r") as proto_file:
            text_format.Merge(proto_file.read(), message)
    else:
        # Binary protos must be read as bytes. In text mode GFile decodes the
        # contents to str, which can fail on arbitrary bytes and is not what
        # ParseFromString expects.
        with gfile.GFile(path, "rb") as proto_file:
            message.ParseFromString(proto_file.read())
    return message


def main(_):
    """Load a graph, optimize it, and print its cost and memory reports.

    Input is taken from FLAGS.metagraphdef (a MetaGraphDef) when set,
    otherwise from FLAGS.graphdef (a GraphDef). Either file may be text
    (".pbtxt") or binary.

    In metagraph mode, FLAGS.fetch (if given) replaces the "train_op"
    collection. In graphdef mode, FLAGS.fetch is required: it names the
    operation added to the "train_op" collection before the graph is
    exported as a MetaGraphDef.

    The optional FLAGS.rewriter_config (text-format RewriterConfig) is
    applied via tf_optimizer.OptimizeGraph before the reports are generated.

    Raises:
      ValueError: if reading a GraphDef and FLAGS.fetch is not set.
    """
    if FLAGS.metagraphdef:
        metagraph = _parse_proto_file(FLAGS.metagraphdef,
                                      meta_graph_pb2.MetaGraphDef())
        if FLAGS.fetch is not None:
            fetch_collection = meta_graph_pb2.CollectionDef()
            fetch_collection.node_list.value.append(FLAGS.fetch)
            metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
    else:
        # Fail early with a clear message instead of deep inside
        # get_operation_by_name(None).
        if FLAGS.fetch is None:
            raise ValueError(
                "--fetch is required when reading a GraphDef via --graphdef.")
        graph_def = _parse_proto_file(FLAGS.graphdef, graph_pb2.GraphDef())
        importer.import_graph_def(graph_def, name="")
        graph = ops.get_default_graph()
        fetch = graph.get_operation_by_name(FLAGS.fetch)
        graph.add_to_collection("train_op", fetch)
        metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def(),
                                            graph=graph)

    rewriter_config = rewriter_config_pb2.RewriterConfig()
    if FLAGS.rewriter_config is not None:
        text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

    report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
    print(report)
    report = cost_analyzer.GenerateMemoryReport(metagraph)
    print(report)