def _get_logged_ops(graph, run_meta=None, add_trace=True, add_trainable_var=True): """Extract trainable model parameters and FLOPs for ops from a Graph. Args: graph: tf.Graph. run_meta: RunMetadata proto used to complete shape information. add_trace: Whether to add op trace information. add_trainable_var: Whether to assign tf.trainable_variables() op type '_trainable_variables'. Returns: logged_ops: dict mapping from op_name to OpLogEntry. """ if run_meta: graph = _fill_missing_graph_shape(graph, run_meta) op_missing_shape = 0 logged_ops = {} # TODO(xpan): Work with Profiler more efficiently. for op in graph.get_operations(): try: stats = ops.get_stats_for_node_def(graph, op.node_def, REGISTERED_FLOP_STATS) except ValueError: # Catch Exception When shape is incomplete. Skip it. op_missing_shape += 1 stats = None entry = tfprof_log_pb2.OpLogEntry() entry.name = op.name add_entry = False if stats and stats.value: entry.float_ops = int(stats.value) add_entry = True if add_trace: for tb in op.traceback_with_start_lines: trace = entry.code_def.traces.add() trace.file = tb[0] if tb[0] else 'none' trace.lineno = tb[1] if tb[1] else -1 trace.function = tb[2] if tb[2] else 'none' trace.line = tb[3] if tb[3] else 'none' trace.func_start_line = tb[4] if tb[4] else -1 add_entry = True if add_entry: logged_ops[entry.name] = entry if add_trainable_var: for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES): if v.op.name not in logged_ops: entry = tfprof_log_pb2.OpLogEntry() entry.name = v.op.name entry.types.append(TRAINABLE_VARIABLES) logged_ops[entry.name] = entry else: logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES) if op_missing_shape > 0 and not run_meta: sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' % op_missing_shape) return logged_ops
def _get_logged_ops(graph, run_meta=None, add_trace=True, add_trainable_var=True): """Extract trainable model parameters and FLOPs for ops from a Graph. Args: graph: tf.Graph. run_meta: RunMetadata proto used to complete shape information. add_trace: Whether to add op trace information. add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op type '_trainable_variables'. Returns: logged_ops: dict mapping from op_name to OpLogEntry. string_to_id: dict mapping from string to id. """ if run_meta: graph = _fill_missing_graph_shape(graph, run_meta) op_missing_shape = 0 logged_ops = {} string_to_id = {} string_to_id['none'] = len(string_to_id) # TODO(xpan): Work with Profiler more efficiently. for op in graph.get_operations(): try: stats = ops.get_stats_for_node_def(graph, op.node_def, REGISTERED_FLOP_STATS) except ValueError: # Catch Exception When shape is incomplete. Skip it. op_missing_shape += 1 stats = None entry = tfprof_log_pb2.OpLogEntry() entry.name = op.name add_entry = False if stats and stats.value: entry.float_ops = int(stats.value) add_entry = True if add_trace: if op.traceback: for filename, lineno, funcname, line in op.traceback: trace = entry.code_def.traces.add() trace.file_id = _str_id(filename, string_to_id) if filename else 0 trace.lineno = lineno if lineno else -1 trace.function_id = _str_id( funcname, string_to_id) if funcname else 0 trace.line_id = _str_id(line, string_to_id) if line else 0 # TODO(slebedev): remove this unused field from the proto. trace.func_start_line = -1 add_entry = True if add_entry: logged_ops[entry.name] = entry if add_trainable_var: for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES): if v.op.name not in logged_ops: entry = tfprof_log_pb2.OpLogEntry() entry.name = v.op.name entry.types.append(TRAINABLE_VARIABLES) logged_ops[entry.name] = entry else: logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES) if op_missing_shape > 0 and not run_meta: sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' % op_missing_shape) return logged_ops, string_to_id