def _get_logged_ops(graph,
                    run_meta=None,
                    add_trace=True,
                    add_trainable_var=True):
    """Extract trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.trainable_variables() op type
      '_trainable_variables'.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
  """
    if run_meta:
        graph = _fill_missing_graph_shape(graph, run_meta)

    op_missing_shape = 0
    logged_ops = {}
    # TODO(xpan): Work with Profiler more efficiently.
    for op in graph.get_operations():
        try:
            stats = ops.get_stats_for_node_def(graph, op.node_def,
                                               REGISTERED_FLOP_STATS)
        except ValueError:
            # Catch Exception When shape is incomplete. Skip it.
            op_missing_shape += 1
            stats = None

        entry = tfprof_log_pb2.OpLogEntry()
        entry.name = op.name
        add_entry = False
        if stats and stats.value:
            entry.float_ops = int(stats.value)
            add_entry = True

        if add_trace:
            for tb in op.traceback_with_start_lines:
                trace = entry.code_def.traces.add()
                trace.file = tb[0] if tb[0] else 'none'
                trace.lineno = tb[1] if tb[1] else -1
                trace.function = tb[2] if tb[2] else 'none'
                trace.line = tb[3] if tb[3] else 'none'
                trace.func_start_line = tb[4] if tb[4] else -1
            add_entry = True

        if add_entry:
            logged_ops[entry.name] = entry

    if add_trainable_var:
        for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
            if v.op.name not in logged_ops:
                entry = tfprof_log_pb2.OpLogEntry()
                entry.name = v.op.name
                entry.types.append(TRAINABLE_VARIABLES)
                logged_ops[entry.name] = entry
            else:
                logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)

    if op_missing_shape > 0 and not run_meta:
        sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
                         op_missing_shape)
    return logged_ops
示例#2
0
def _get_logged_ops(graph,
                    run_meta=None,
                    add_trace=True,
                    add_trainable_var=True):
    """Extract trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
    string_to_id: dict mapping from string to id.
  """
    if run_meta:
        graph = _fill_missing_graph_shape(graph, run_meta)

    op_missing_shape = 0
    logged_ops = {}
    string_to_id = {}
    string_to_id['none'] = len(string_to_id)
    # TODO(xpan): Work with Profiler more efficiently.
    for op in graph.get_operations():
        try:
            stats = ops.get_stats_for_node_def(graph, op.node_def,
                                               REGISTERED_FLOP_STATS)
        except ValueError:
            # Catch Exception When shape is incomplete. Skip it.
            op_missing_shape += 1
            stats = None

        entry = tfprof_log_pb2.OpLogEntry()
        entry.name = op.name
        add_entry = False
        if stats and stats.value:
            entry.float_ops = int(stats.value)
            add_entry = True

        if add_trace:
            if op.traceback:
                for filename, lineno, funcname, line in op.traceback:
                    trace = entry.code_def.traces.add()
                    trace.file_id = _str_id(filename,
                                            string_to_id) if filename else 0
                    trace.lineno = lineno if lineno else -1
                    trace.function_id = _str_id(
                        funcname, string_to_id) if funcname else 0
                    trace.line_id = _str_id(line, string_to_id) if line else 0
                    # TODO(slebedev): remove this unused field from the proto.
                    trace.func_start_line = -1
            add_entry = True

        if add_entry:
            logged_ops[entry.name] = entry

    if add_trainable_var:
        for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
            if v.op.name not in logged_ops:
                entry = tfprof_log_pb2.OpLogEntry()
                entry.name = v.op.name
                entry.types.append(TRAINABLE_VARIABLES)
                logged_ops[entry.name] = entry
            else:
                logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)

    if op_missing_shape > 0 and not run_meta:
        sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
                         op_missing_shape)
    return logged_ops, string_to_id