Ejemplo n.º 1
0
def print_model_analysis(graph,
                         run_meta=None,
                         op_log=None,
                         tfprof_cmd='scope',
                         tfprof_options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS):
    """Print model statistics.

    See go/tfprof or README for examples and tutorials.
    Run tfprof tool for help:
    'bazel run third_party/tensorflow/tools/tfprof help'

    Args:
      graph: tf.Graph.
      run_meta: tensorflow::RunMetadata proto. When provided, also shows valid
                timing and memory information when 'select' option contains
                'micros' and 'bytes'.
      op_log: tensorflow::tfprof::OpLog proto. Users can use this proto to
              group together ops and use an op_type to select the group.
      tfprof_cmd: string. Either 'op', 'scope', 'graph', 'code'.
                  'op' view organizes outputs using operation type
                  (e.g. MatMul).
                  'scope' view organizes outputs using graph node name scope.
                  'graph' view organizes outputs using graph node
                  inputs/outputs.
                  'code' view organizes outputs using Python call stack.
      tfprof_options: See 'tfprof help' for details.

    Returns:
      If tfprof_cmd is 'scope' or 'graph', returns TFGraphNodeProto proto.
      If tfprof_cmd is 'op' or 'code', returns TFMultiGraphNodeProto proto.
      Side effect: stdout/file/timeline.json depending on
      tfprof_options['output'].

    Raises:
      errors.InvalidArgumentError: if tfprof_cmd is not one of 'op', 'scope',
        'graph' or 'code'. The command is checked before any op-log merging
        or graph serialization is done.
    """
    # Select the result proto type up front so an unknown command fails fast,
    # before any op-log merging, option building or graph serialization.
    if tfprof_cmd in ('code', 'op'):
        tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
    elif tfprof_cmd in ('graph', 'scope'):
        tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
    else:
        raise errors.InvalidArgumentError(
            None, None, 'unknown tfprof_cmd: %s\n' % tfprof_cmd)

    # pylint: disable=protected-access
    # Python call-stack traces are only requested for the 'code' view.
    op_log = tfprof_logger._merge_default_with_oplog(
        graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
    # pylint: enable=protected-access

    opts = _build_options(tfprof_options)

    run_meta_str = run_meta.SerializeToString() if run_meta else b''

    # The native entry point takes serialized protos and returns a serialized
    # result, which is parsed into the proto type chosen above.
    tfprof_node.ParseFromString(
        print_mdl.PrintModelAnalysis(
            graph.as_graph_def(add_shapes=True).SerializeToString(),
            run_meta_str, op_log.SerializeToString(),
            tfprof_cmd.encode('utf-8'), opts.SerializeToString()))

    return tfprof_node
Ejemplo n.º 2
0
    def profile_operations(self, options):
        """Profile statistics grouped by operation type (e.g. MatMul, Conv2D).

        Args:
          options: A dict of profiler options.

        Returns:
          A TFMultiGraphNodeProto holding the profiling results.
        """
        serialized_opts = _build_options(options).SerializeToString()
        # The native profiler takes the view name as bytes and returns a
        # serialized TFMultiGraphNodeProto.
        result_bytes = print_mdl.Profile(b'op', serialized_opts)
        multi_node = tfprof_output_pb2.TFMultiGraphNodeProto()
        multi_node.ParseFromString(result_bytes)
        return multi_node
Ejemplo n.º 3
0
  def profile_python_codes(self, options):
    """Profile the statistics of the Python code (the 'code' view).

    Hint: set options['show_name_regexes'] = ['.*my_code.py.*']

    Args:
      options: A dict of profiler options.

    Returns:
      A TFMultiGraphNodeProto holding the profiling results.
    """
    serialized_opts = _build_options(options).SerializeToString()
    # The native profiler takes the view name as bytes and returns a
    # serialized TFMultiGraphNodeProto.
    result_bytes = print_mdl.Profile(b'code', serialized_opts)
    multi_node = tfprof_output_pb2.TFMultiGraphNodeProto()
    multi_node.ParseFromString(result_bytes)
    return multi_node