def __init__(self, model, dummy_input, verbose=False):
    super().__init__(model, dummy_input)
    from tensorboard.compat.proto.config_pb2 import RunMetadata
    from tensorboard.compat.proto.graph_pb2 import GraphDef
    from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
    from tensorboard.compat.proto.versions_pb2 import VersionDef

    list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
    if verbose:
        print(self.trace.graph)
    self.stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    self.graph_def = GraphDef(node=list_of_nodes,
                              versions=VersionDef(producer=22))
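# Hedged usage sketch for the constructor above. It assumes this __init__
# belongs to the GraphVisitor class used by add_graph further down, and that
# its base class traces the model into `self.trace` and provides `self.parse`;
# neither is shown in this section, so treat the names as assumptions.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dummy_input = torch.randn(1, 4)

visitor = GraphVisitor(model, dummy_input)  # hypothetical class name
graph_def, stepstats = visitor.graph_def, visitor.stepstats
# Both protos can then be handed to a TensorBoard file writer, e.g. via
# writer._get_file_writer().add_graph((graph_def, stepstats)).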
def graph(model, args, verbose=False, use_strict_trace=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
        use_strict_trace (bool): Whether to pass keyword argument `strict` to
            `torch.jit.trace`. Pass False when you want the tracer to record
            your mutable container types (list, dict).
    """
    with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            graph = trace.graph
            torch._C._jit_pass_inline(graph)
        except RuntimeError as e:
            print(e)
            print('Error occurred; no graph saved')
            raise

    if verbose:
        print(graph)
    list_of_nodes = parse(graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have
    # actually run on GPU. Note this is what is shown in TensorBoard and has
    # no bearing on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch
    # model and pass it correctly to TensorBoard.
    #
    # Definitions of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
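# Hedged usage sketch for the graph() variant above. The usual public entry
# point is SummaryWriter.add_graph; calling graph() directly as shown here
# mirrors what that method does internally (an assumption based on this
# snippet alone). It also assumes the call happens inside the module that
# defines parse() and imports the tensorboard protos.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

graph_def, stepstats = graph(TinyNet(), (torch.randn(1, 4),), verbose=False)
# graph_def is a tensorboard GraphDef proto; stepstats is the RunMetadata
# proto that TensorBoard's graph plugin expects alongside it.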
def graph(model, args, verbose=False, operator_export_type='ONNX',
          omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
        operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
            Defaults to 'ONNX' because it outputs the most visually
            understandable format.
        omit_useless_nodes (bool): Whether to remove nodes whose outputs are
            never used from the graph.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurred; no graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's an ONNX problem...")
            try:
                import tempfile
                torch.onnx.export(
                    model, args, tempfile.TemporaryFile(), verbose=True)
            except RuntimeError:
                print("Your model cannot be exported by ONNX; "
                      "please report to the ONNX team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Most users don't care about those detailed nodes.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimizing the trace might fail (due to bad scopes in some cases
        # we've seen), and we don't want graph visualization to fail in that
        # case. Log a warning and display the non-optimized graph instead.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have
    # actually run on GPU. Note this is what is shown in TensorBoard and has
    # no bearing on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch
    # model and pass it correctly to TensorBoard.
    #
    # Definitions of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
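# Hedged sketch for the older tensorboardX-style graph() above. Note it relies
# on torch.jit.get_trace_graph and torch.onnx.set_training, which only exist
# in older PyTorch releases. Passing 'RAW' skips the _jit_pass_onnx conversion
# (see the RAW check in the variant further down), keeping more scope
# information at the cost of noisier low-level nodes.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
dummy = (torch.randn(1, 3, 32, 32),)
graph_def, stepstats = graph(model, dummy, operator_export_type='RAW',
                             omit_useless_nodes=False)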
def add_graph(self, model, *args, **kwargs):
    visitor = GraphVisitor(model, *args, **kwargs)
    stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    graph = GraphDef(node=visitor._graph, versions=VersionDef(producer=22))
    self._get_file_writer().add_graph((graph, stepstats))
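# Hedged usage sketch for add_graph() above, assuming it is a method on a
# SummaryWriter-like class whose _get_file_writer() returns a TensorBoard
# FileWriter (as the call signature suggests; the class itself is not shown):
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)
writer = SummaryWriter("./runs/graph_demo")  # hypothetical log dir
writer.add_graph(model, dummy_input)
writer.close()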
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load a MegEngine dumped model and visualize its graph structure with
    TensorBoard log files. It can also record and print the model's
    statistics, like :func:`~.module_stats`.

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param input: user-defined input data for running the model and
        calculating stats; alternative to inp_dict, used when the model has
        only one input.
    :param inp_dict: input dict for running the model and calculating stats;
        alternative to input, used when the model has more than one input.
        When both input and inp_dict are None, a random input will be used.
    :param cal_params: whether to calculate and record params size.
    :param cal_flops: whether to calculate and record op flops.
    :param cal_activations: whether to calculate and record op activations.
    :param logging_to_stdout: whether to print all calculated statistic
        details.
    :param bar_length_max: size of the bar indicating max flops or parameter
        size in net stats.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, (
                "The graph has more than one input; please use inp_dict."
            )
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)
        graph._compile()

    def process_name(name):
        # node names that start with a point or contain a float constant
        # cause a display bug in TensorBoard
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and does not have an "
                    "'output_idx' attr.".format(node)
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d)
                                    for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    if cal_flops and cal_params:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
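# Hedged usage sketch for visualize() above, assuming a model has already been
# dumped with MegEngine's trace/dump workflow (all paths here are hypothetical):
import numpy as np

total, details = visualize(
    model_path="./dumped_model.mge",
    log_path="./logs/graph",
    input=np.random.random((1, 3, 224, 224)).astype(np.float32),
    logging_to_stdout=False,
)
print(total.param_size, total.flops, total.act_size)
# Then inspect the graph with: tensorboard --logdir ./logs/graph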
def graph(model, args, verbose=False, operator_export_type='ONNX',
          omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
        operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
            Defaults to 'ONNX' because it outputs the most visually
            understandable format.
        omit_useless_nodes (bool): Whether to remove nodes whose outputs are
            never used from the graph.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    # This code is similar to torch/onnx/utils.py, but adjusted to provide
    # the most visually understandable output. For example, the line
    #
    #     # torch._C._jit_pass_onnx_peephole(graph)
    #
    # is commented out because that pass removes a lot of scope information.
    # The amount of optimization cannot be too much (lots of information lost)
    # or too little (too much useless information); the code is therefore
    # copy-pasted here so that it will not be affected by torch/onnx/utils.py
    # changes.
    def _optimize_trace(trace, operator_export_type):
        trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))

    def _optimize_graph(graph, operator_export_type):
        # torch._C._jit_pass_remove_inplace_ops(graph)
        # We now record some ops like ones/zeros into a trace where we
        # previously recorded constants. Use constant propagation to maintain
        # the current level of ONNX support without implementing symbolics
        # for all of them.
        torch._C._jit_pass_constant_propagation(graph)
        torch.onnx.utils._split_tensor_list_constants(graph, graph)
        # Run DCE to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override.
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        # torch._C._jit_pass_canonicalize_ops(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)
        # ONNX only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
        torch._C._jit_pass_prepare_division_for_onnx(graph)
        # ONNX only supports tensors, so we turn all number types into tensors.
        torch._C._jit_pass_erase_number_types(graph)
        # ONNX does not support tuples, so try to remove them.
        torch._C._jit_pass_lower_all_tuples(graph)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        if operator_export_type != OperatorExportTypes.RAW:
            graph = torch._C._jit_pass_onnx(graph, operator_export_type)
            torch._C._jit_pass_lint(graph)
            # torch._C._jit_pass_onnx_peephole(graph)
            torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_fixup_onnx_loops(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        return graph

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurred; no graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's an ONNX problem...")
            try:
                import tempfile
                torch.onnx.export(model, args, tempfile.TemporaryFile(),
                                  verbose=True)
            except RuntimeError:
                print("Your model fails ONNX too; please report to the ONNX team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Most users don't care about those detailed nodes.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimizing the trace might fail (due to bad scopes in some cases
        # we've seen), and we don't want graph visualization to fail in that
        # case. Log a warning and display the non-optimized graph instead.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have
    # actually run on GPU. Note this is what is shown in TensorBoard and has
    # no bearing on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch
    # model and pass it correctly to TensorBoard.
    #
    # Definitions of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
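# Hedged sketch showing how the (GraphDef, RunMetadata) pair returned above is
# typically consumed: tensorboardX's FileWriter.add_graph accepts exactly this
# tuple, as the add_graph wrapper and the MegEngine visualizers in this
# section already do. The log dir is hypothetical.
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

model = nn.Conv2d(3, 16, kernel_size=3)
graph_def, stepstats = graph(model, (torch.randn(1, 3, 32, 32),))
writer = SummaryWriter("./runs/onnx_graph")
writer._get_file_writer().add_graph((graph_def, stepstats))
writer.close()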
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load a MegEngine dumped model and visualize its graph structure with
    TensorBoard log files. It can also record and print the model's
    statistics, like :func:`~.module_stats`.

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param bar_length_max: size of the bar indicating max flops or parameter
        size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    # FIXME: remove this after resolving the "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)

    enable_receptive_field()

    graph = Network.load(model_path)

    def process_name(name):
        # node names that start with a point or contain a float constant
        # cause a display bug in TensorBoard
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and does not have an "
                    "'output_idx' attr.".format(node))
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(list=AttrValue.ListValue(shape=[
                    TensorShapeProto(dim=[
                        TensorShapeProto.Dim(size=d) for d in node_oup.shape
                    ])
                ])),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }
        flops_stats = get_op_stats(node, node.inputs, node.outputs)
        if flops_stats is not None:
            # add op flops attr
            if log_path and "flops" in flops_stats:
                attr["flops"] = AttrValue(
                    s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8"))
            flops_stats["name"] = node.name
            flops_stats["class_name"] = node.type
            flops_list.append(flops_stats)

        if node.type == "ImmutableTensor":
            param_stats = get_param_stats(node.numpy())
            # add tensor size attr
            if log_path:
                attr["size"] = AttrValue(
                    s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8"))
            param_stats["name"] = node.name
            params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                ))

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    total_flops, total_param_dims, total_param_size = 0, 0, 0
    if log_params:
        total_param_dims, total_param_size = print_param_stats(
            params_list, bar_length_max)
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_flops:
        total_flops = print_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size)

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
        device = "/device:CPU:0"
        stepstats = RunMetadata(step_stats=StepStats(
            dev_stats=[DeviceStepStats(device=device)]))
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    # FIXME: remove this after resolving the "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops
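# Hedged usage sketch for the older visualize() signature above. Unlike the
# newer variant, it returns a plain (param_size, flops) tuple. Paths are
# hypothetical.
total_param_size, total_flops = visualize(
    model_path="./dumped_model.mge",
    log_path="./logs/graph_old",
    log_params=True,
    log_flops=True,
)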