def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
    """Merge the tfprof default extra info with caller's op_log.

    The entries produced by `_get_logged_ops` are combined with the entries
    of `op_log`, keyed by entry name. When a name is present on both sides,
    the caller's entry wins and is only supplemented:
      * the logged types are appended to the caller's types;
      * `float_ops` is taken from the logged entry only when the caller's
        value is 0 and the logged value is positive;
      * `code_def` is merged in only when the caller's entry has no traces.
    If several entries of `op_log` share a name, only the last one is kept.

    The caller's `op_log` is left untouched; merging happens on copies.

    Args:
      graph: tf.Graph. If None and eager execution is not enabled, use
          default graph.
      op_log: OpLogProto proto.
      run_meta: RunMetadata proto used to complete shape information.
      add_trace: Whether to add op trace information.
      add_trainable_var: Whether to assign tf.compat.v1.trainable_variables()
          op type '_trainable_variables'.

    Returns:
      tmp_op_log: Merged OpLogProto proto. It is empty when no graph is
          available.
    """
    if not graph and not context.executing_eagerly():
        graph = ops.get_default_graph()

    tmp_op_log = tfprof_log_pb2.OpLogProto()
    if not graph:
        return tmp_op_log

    logged_ops, string_to_id = _get_logged_ops(
        graph, run_meta, add_trace=add_trace,
        add_trainable_var=add_trainable_var)

    if not op_log:
        tmp_op_log.log_entries.extend(logged_ops.values())
    else:
        # Extending a repeated message field copies each element, so the
        # merge below edits private copies instead of the caller's entries.
        caller_copy = tfprof_log_pb2.OpLogProto()
        caller_copy.log_entries.extend(op_log.log_entries)

        all_ops = {}
        for entry in caller_copy.log_entries:
            all_ops[entry.name] = entry

        for op_name, entry in six.iteritems(logged_ops):
            existing = all_ops.get(op_name)
            if existing is None:
                all_ops[op_name] = entry
                continue
            existing.types.extend(entry.types)
            if entry.float_ops > 0 and existing.float_ops == 0:
                existing.float_ops = entry.float_ops
            if entry.code_def.traces and not existing.code_def.traces:
                existing.code_def.MergeFrom(entry.code_def)

        tmp_op_log.log_entries.extend(all_ops.values())

    # string_to_id maps string -> id; the proto stores the inverse mapping.
    for s, i in six.iteritems(string_to_id):
        tmp_op_log.id_to_string[i] = s
    return tmp_op_log
def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
                              add_trace=True, add_trainable_var=True):
    """Merge the tfprof default extra info with caller's op_log.

    The entries produced by `_get_logged_ops` are combined with the entries
    of `op_log`, keyed by entry name. When a name is present on both sides,
    the caller's entry wins and is only supplemented:
      * the logged types are appended to the caller's types;
      * `float_ops` is taken from the logged entry only when the caller's
        value is 0 and the logged value is positive;
      * `code_def` is merged in only when the caller's entry has no traces.
    If several entries of `op_log` share a name, only the last one is kept.

    The caller's `op_log` is left untouched; merging happens on copies.

    Args:
      graph: tf.Graph.
      op_log: OpLogProto proto.
      run_meta: RunMetadata proto used to complete shape information.
      add_trace: Whether to add op trace information.
      add_trainable_var: Whether to assign tf.trainable_variables() op type
          '_trainable_variables'.

    Returns:
      tmp_op_log: Merged OpLogProto proto.
    """
    tmp_op_log = tfprof_log_pb2.OpLogProto()
    logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace,
                                 add_trainable_var=add_trainable_var)

    if not op_log:
        tmp_op_log.log_entries.extend(logged_ops.values())
        return tmp_op_log

    # Extending a repeated message field copies each element, so the merge
    # below edits private copies instead of the caller's entries.
    caller_copy = tfprof_log_pb2.OpLogProto()
    caller_copy.log_entries.extend(op_log.log_entries)

    all_ops = {}
    for entry in caller_copy.log_entries:
        all_ops[entry.name] = entry

    for op_name, entry in six.iteritems(logged_ops):
        existing = all_ops.get(op_name)
        if existing is None:
            all_ops[op_name] = entry
            continue
        existing.types.extend(entry.types)
        if entry.float_ops > 0 and existing.float_ops == 0:
            existing.float_ops = entry.float_ops
        if entry.code_def.traces and not existing.code_def.traces:
            existing.code_def.MergeFrom(entry.code_def)

    tmp_op_log.log_entries.extend(all_ops.values())
    return tmp_op_log