def _operators_to_graph_def(
    shapes,
    ops,
    colon_replacement='$',
    with_ssa=True,
    with_gradient_scope=True,
    blob_name_tracker=None,
    show_simplified=False,
    custom_rename=None,
):
    '''
    Main function to convert a set of operators to a graph.

    Args:
        shapes: Dictionary mapping blob names to their shapes/dimensions.
        ops: List of Caffe2 operators, representing some computation graph.
        ### **kwargs (model_to_graph_def, nets_to_graph_def, protos_to_graph_def) ###
        colon_replacement: Symbol to replace ':' with. ':i' in TF has a special
            meaning, so we need to replace it with a non-conflicting symbol.
        with_ssa: Boolean
        with_gradient_scope: Boolean
        blob_name_tracker: Dictionary tracking names of blobs (inputs/outputs
            from operators)
        show_simplified: Whether to show a simplified version of the model graph.
            Sets all of the following values:
                clear_debug_info: Boolean representing whether to silence debug
                    info (which can be very verbose)
                show_forward_only: Boolean representing whether to only show
                    blobs involved in the forward pass
                show_cpu_only: Boolean representing whether to only show blobs
                    that are not associated with a gpu
                use_tensorflow_naming: Boolean representing whether to convert
                    some common Caffe2 naming conventions to their Tensorflow
                    counterparts
        custom_rename: Function string -> string that defines a custom renaming
            function to use.

    Returns:
        current_graph: GraphDef representing the computation graph formed by the
            set of operators.
    '''
    if blob_name_tracker is not None:
        blob_name_tracker.clear()
    else:
        blob_name_tracker = {}

    blob_name_tracker.update(_get_blob_names(ops))

    _clear_debug_info(ops, show_simplified)  # clear_debug_info
    ops = _filter_ops(ops, _check_if_forward, show_simplified)  # show_forward_only
    ops = _filter_ops(ops, _check_if_cpu, show_simplified)  # show_cpu_only
    if custom_rename:
        _rename_all(shapes, blob_name_tracker, ops, custom_rename)
    if colon_replacement:
        _replace_colons(shapes, blob_name_tracker, ops, colon_replacement)
    if with_ssa:
        _convert_to_ssa(shapes, blob_name_tracker, ops)
    if with_gradient_scope:
        _add_gradient_scope(shapes, blob_name_tracker, ops)
    _fill_missing_operator_names(ops)
    if show_simplified:  # use_tensorflow_naming
        _rename_tensorflow_style(shapes, blob_name_tracker, ops)

    producing_ops = {}
    blobs = set()
    input_blobs, inter_blobs, _ = _compute_in_out(ops)
    current_graph = GraphDef()
    seen = set(input_blobs)
    for op in ops:
        nodes_from_op = (
            _operator_to_node_simp(op, inter_blobs, seen)
            if show_simplified
            else [_operator_to_node(shapes, op)]
        )  # .extend() expects an iterable
        current_graph.node.extend(nodes_from_op)
        for input_blob in op.input:
            blobs.add(input_blob)
        for i, output_blob in enumerate(op.output):
            blobs.add(output_blob)
            producing_ops.setdefault(output_blob, []).append((op, i))

    if show_simplified:
        # Show a cleaner, easier-to-interpret version of the model graph
        blobs = input_blobs

    for blob in blobs:
        current_graph.node.extend([_blob_to_node(producing_ops, {}, blob)])

    return current_graph
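# Hedged usage sketch (not from the source): converting a single hand-built Caffe2
# operator into a GraphDef via _operators_to_graph_def. Assumes caffe2 is installed;
# the operator and the shape values are purely illustrative.
from caffe2.proto import caffe2_pb2

relu_op = caffe2_pb2.OperatorDef()
relu_op.type = "Relu"
relu_op.name = "relu_1"
relu_op.input.extend(["X"])
relu_op.output.extend(["Y"])

# shapes maps blob names to dimensions; ops is the list of operators to convert.
graph_def = _operators_to_graph_def(shapes={"X": [1, 16]}, ops=[relu_op])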
def keras_model_to_graph_def(keras_layer):
    """Returns a GraphDef representation of the Keras model in a dict form.

    Note that it only supports models that implement to_json().

    Args:
        keras_layer: A dict from Keras model.to_json().

    Returns:
        A GraphDef representation of the layers in the model.
    """
    input_to_layer = {}
    model_name_to_output = {}
    g = GraphDef()

    # Sequential model layers do not have a field "inbound_nodes" but
    # instead are defined implicitly via order of layers.
    prev_node_name = None

    for (name_scope, layer) in _walk_layers(keras_layer):
        if _is_model(layer):
            (
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            ) = _update_dicts(
                name_scope,
                layer,
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            )
            continue

        layer_config = layer.get("config")
        node_name = _scoped_name(name_scope, layer_config.get("name"))

        node_def = g.node.add()
        node_def.name = node_name

        if layer.get("class_name") is not None:
            keras_cls_name = layer.get("class_name").encode("ascii")
            node_def.attr["keras_class"].s = keras_cls_name

        dtype_or_policy = layer_config.get("dtype")
        # Skip dtype processing if this is a dict, since it's presumably an instance of
        # tf/keras/mixed_precision/Policy rather than a single dtype.
        # TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
        if dtype_or_policy is not None and not isinstance(dtype_or_policy, dict):
            tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
            node_def.attr["dtype"].type = tf_dtype.as_datatype_enum

        if layer.get("inbound_nodes") is not None:
            for maybe_inbound_node in layer.get("inbound_nodes"):
                inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
                for [name, size, index, _] in inbound_nodes:
                    inbound_name = _scoped_name(name_scope, name)
                    # An input to a layer can be output from a model. In that case, the name
                    # of inbound_nodes to a layer is a name of a model. Remap the name of the
                    # model to output layer of the model. Also, since there can be multiple
                    # outputs in a model, make sure we pick the right output_layer from the model.
                    inbound_node_names = model_name_to_output.get(
                        inbound_name, [inbound_name]
                    )
                    node_def.input.append(inbound_node_names[index])
        elif prev_node_name is not None:
            node_def.input.append(prev_node_name)

        if node_name in input_to_layer:
            node_def.input.append(input_to_layer.get(node_name))

        prev_node_name = node_def.name

    return g
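# Hedged usage sketch (not from the source): feeding a small Sequential model's
# to_json() output through keras_model_to_graph_def. Assumes TensorFlow/Keras is
# available; layer names and sizes are illustrative.
import json
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,), name="dense_1"),
    tf.keras.layers.Dense(1, name="dense_2"),
])

# keras_model_to_graph_def expects the dict form of model.to_json().
model_config = json.loads(model.to_json())
graph_def = keras_model_to_graph_def(model_config)
print(len(graph_def.node))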
def test_combine_graph_defs_function_collision(self):
    graph_def_a = GraphDef()
    text_format.Merge(
        '''
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        ''',
        graph_def_a)

    graph_def_b = GraphDef()
    text_format.Merge(
        '''
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "div" op: "Div" input: "x" input: "y" }
          }
          function {
            signature {
              name: "foo_1"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        ''',
        graph_def_b)

    with six.assertRaisesRegex(
            self,
            ValueError,
            ('Cannot combine GraphDefs because functions share a name but '
             'are different: foo')):
        graph_util.combine_graph_defs(graph_def_a, graph_def_b)
def test_combine_graph_defs_src_function_duplicate_keys(self):
    graph_def_a = GraphDef()
    text_format.Merge(
        '''
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        ''',
        graph_def_a)

    graph_def_b = GraphDef()
    text_format.Merge(
        '''
        library {
          function {
            signature {
              name: "bar"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
          }
          function {
            signature {
              name: "bar"
              input_arg { name: "y" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
          }
        }
        ''',
        graph_def_b)

    with six.assertRaisesRegex(
            self,
            ValueError,
            'A GraphDef contains non-unique function names: bar'):
        graph_util.combine_graph_defs(graph_def_a, graph_def_b)
def test_combine_graph_defs(self):
    expected_proto = '''
        node { name: "X" op: "Input" }
        node { name: "W" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "W" }
        node { name: "A" op: "Input" }
        node { name: "B" op: "Input" }
        node { name: "C" op: "MatMul" input: "A" input: "B" }
        versions { producer: 21 }
    '''

    graph_def_a = GraphDef()
    text_format.Merge(
        '''
        node { name: "X" op: "Input" }
        node { name: "W" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "W" }
        versions { producer: 21 }
        ''',
        graph_def_a)

    graph_def_b = GraphDef()
    text_format.Merge(
        '''
        node { name: "A" op: "Input" }
        node { name: "B" op: "Input" }
        node { name: "C" op: "MatMul" input: "A" input: "B" }
        versions { producer: 21 }
        ''',
        graph_def_b)

    self.assertProtoEquals(
        expected_proto,
        graph_util.combine_graph_defs(graph_def_a, graph_def_b))
def test_combine_graph_defs_function(self):
    expected_proto = '''
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
          function {
            signature {
              name: "foo_1"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
    '''

    graph_def_a = GraphDef()
    text_format.Merge(
        '''
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        ''',
        graph_def_a)

    graph_def_b = GraphDef()
    text_format.Merge(
        '''
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
          function {
            signature {
              name: "foo_1"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        ''',
        graph_def_b)

    self.assertProtoEquals(
        expected_proto,
        graph_util.combine_graph_defs(graph_def_a, graph_def_b))
def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while processing.
        operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
            Defaults to 'ONNX' format because it outputs the most visually
            understandable format.
        omit_useless_nodes (boolean): Whether to remove nodes from the graph.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    # This code is similar to torch/onnx/utils.py, but adjusted to provide
    # the most visually understandable output.
    #
    # For example, `torch._C._jit_pass_onnx_peephole(graph)` is commented out
    # below because that pass removes a lot of scope information. The amount of
    # optimization cannot be too much (lots of information lost) or too little
    # (too much useless information), therefore the code is copy-pasted here so
    # that it will not be affected by torch/onnx/utils.py changes.
    def _optimize_trace(trace, operator_export_type):
        trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))

    def _optimize_graph(graph, operator_export_type):
        # torch._C._jit_pass_remove_inplace_ops(graph)
        # We now record some ops like ones/zeros into a trace where we
        # previously recorded constants. Use constant prop to maintain our
        # current level of ONNX support without implementing symbolics for
        # all of them.
        torch._C._jit_pass_constant_propagation(graph)
        torch.onnx.utils._split_tensor_list_constants(graph, graph)
        # Run DCE to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override.
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        # torch._C._jit_pass_canonicalize_ops(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        # ONNX only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
        torch._C._jit_pass_prepare_division_for_onnx(graph)
        # ONNX only supports tensors, so we turn all our number types into tensors
        torch._C._jit_pass_erase_number_types(graph)
        # ONNX does not support tuples, so try to remove them
        torch._C._jit_pass_lower_all_tuples(graph)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        if operator_export_type != OperatorExportTypes.RAW:
            graph = torch._C._jit_pass_onnx(graph, operator_export_type)
            torch._C._jit_pass_lint(graph)
            # torch._C._jit_pass_onnx_peephole(graph)
            torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_fixup_onnx_loops(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        return graph

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurred, no graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's an onnx problem...")
            try:
                import tempfile
                torch.onnx.export(model, args, tempfile.TemporaryFile(), verbose=True)
            except RuntimeError:
                print("Your model fails onnx too, please report to the onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about those
        # detailed nodes' information.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimizing the trace might fail (due to bad scopes in some cases
        # we've seen) and we don't want graph visualization to fail in this
        # case, so we log the warning and display the non-optimized graph.
        logging.warning(e)

    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have
    # actually run on GPU. Note this is what is shown in TensorBoard and has no
    # bearing on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
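# Hedged usage sketch (not from the source): calling graph() above on a tiny module.
# Assumes the legacy PyTorch APIs it relies on (torch.jit.get_trace_graph,
# torch.onnx.set_training) exist in the installed torch version.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
dummy_input = (torch.zeros(1, 4),)

graph_def, stepstats = graph(model, dummy_input, verbose=False)
print(len(graph_def.node))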
def test_combine_graph_defs_name_collided_but_same_content(self):
    expected_proto = """
        node { name: "X" op: "Input" }
        node { name: "W" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "W" }
        node { name: "A" op: "Input" }
        versions { producer: 21 }
    """

    graph_def_a = GraphDef()
    text_format.Merge(
        """
        node { name: "X" op: "Input" }
        node { name: "W" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "W" }
        versions { producer: 21 }
        """,
        graph_def_a,
    )

    graph_def_b = GraphDef()
    text_format.Merge(
        """
        node { name: "X" op: "Input" }
        node { name: "A" op: "Input" }
        versions { producer: 21 }
        """,
        graph_def_b,
    )

    self.assertProtoEquals(
        expected_proto,
        graph_util.combine_graph_defs(graph_def_a, graph_def_b),
    )
def test_merge_graph_defs_function(self):
    expected_proto = """
        library {
          function {
            signature {
              name: "graph_1_foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
          function {
            signature {
              name: "graph_2_foo"
              input_arg { name: "x" type: DT_INT32 }
              output_arg { name: "identity" type: DT_INT32 }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
          function {
            signature {
              name: "graph_2_foo_1"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
    """

    graph_def_a = GraphDef()
    text_format.Parse(
        """
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        """,
        graph_def_a,
    )

    graph_def_b = GraphDef()
    text_format.Parse(
        """
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_INT32 }
              output_arg { name: "identity" type: DT_INT32 }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
          function {
            signature {
              name: "foo_1"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
            node_def { name: "add" op: "Add" input: "x" input: "y" }
          }
        }
        """,
        graph_def_b,
    )

    self.assertProtoEquals(
        expected_proto,
        graph_util.merge_graph_defs([graph_def_a, graph_def_b]),
    )
def test_merge_graph_defs_partitioned_call_remap(self):
    expected_proto = GraphDef()
    text_format.Parse(
        """
        node {
          name: "graph_1/X"
          op: "PartitionedCall"
          attr {
            key: "f"
            value { func { name: "graph_1_foo" } }
          }
        }
        library {
          function {
            signature {
              name: "graph_1_foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
          }
        }
        """,
        expected_proto,
    )

    graph_def_a = GraphDef()
    text_format.Parse(
        """
        node {
          name: "X"
          op: "PartitionedCall"
          attr {
            key: "f"
            value { func { name: "foo" } }
          }
        }
        library {
          function {
            signature {
              name: "foo"
              input_arg { name: "x" type: DT_HALF }
              output_arg { name: "identity" type: DT_HALF }
            }
          }
        }
        """,
        graph_def_a,
    )
    graph_def_b = GraphDef()

    self.assertProtoEquals(
        expected_proto,
        graph_util.merge_graph_defs([graph_def_a, graph_def_b]),
    )
def test_merge_graph_defs(self):
    expected_proto = """
        node { name: "graph_1/X" op: "Input" }
        node { name: "graph_1/W" op: "Input" }
        node { name: "graph_1/Y" op: "MatMul" input: "graph_1/X" input: "graph_1/W" }
        node { name: "graph_2/A" op: "Input" }
        node { name: "graph_2/B" op: "Input" }
        node { name: "graph_2/C" op: "MatMul" input: "graph_2/A" input: "graph_2/B" }
        node { name: "graph_3/A" op: "Input" }
        node { name: "graph_3/B" op: "Input" }
        versions { producer: 21 }
    """

    graph_def_a = GraphDef()
    text_format.Parse(
        """
        node { name: "X" op: "Input" }
        node { name: "W" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "W" }
        versions { producer: 21 }
        """,
        graph_def_a,
    )

    graph_def_b = GraphDef()
    text_format.Parse(
        """
        node { name: "A" op: "Input" }
        node { name: "B" op: "Input" }
        node { name: "C" op: "MatMul" input: "A" input: "B" }
        versions { producer: 21 }
        """,
        graph_def_b,
    )

    graph_def_c = GraphDef()
    text_format.Parse(
        """
        node { name: "A" op: "Input" }
        node { name: "B" op: "Input" }
        versions { producer: 21 }
        """,
        graph_def_c,
    )

    self.assertProtoEquals(
        expected_proto,
        graph_util.merge_graph_defs([graph_def_a, graph_def_b, graph_def_c]),
    )
def test_merge_graph_defs_name_collided_with_same_content(self):
    expected_proto = """
        node { name: "graph_1/X" op: "Input" }
        node { name: "graph_1/W" op: "Input" }
        node { name: "graph_1/Y" op: "MatMul" input: "graph_1/X" input: "graph_1/W" }
        node { name: "graph_2/X" op: "Input" }
        node { name: "graph_2/A" op: "Input" }
        node { name: "graph_2/Y" op: "MatMul" input: "graph_2/X" input: "graph_2/A" }
        versions { producer: 21 }
    """

    graph_def_a = GraphDef()
    text_format.Parse(
        """
        node { name: "X" op: "Input" }
        node { name: "W" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "W" }
        versions { producer: 21 }
        """,
        graph_def_a,
    )

    graph_def_b = GraphDef()
    text_format.Parse(
        """
        node { name: "X" op: "Input" }
        node { name: "A" op: "Input" }
        node { name: "Y" op: "MatMul" input: "X" input: "A" }
        versions { producer: 21 }
        """,
        graph_def_b,
    )

    self.assertProtoEquals(
        expected_proto,
        graph_util.merge_graph_defs([graph_def_a, graph_def_b]),
    )
def keras_model_to_graph_def(keras_layer):
    """Returns a GraphDef representation of the Keras model in a dict form.

    Note that it only supports models that implement to_json().

    Args:
        keras_layer: A dict from Keras model.to_json().

    Returns:
        A GraphDef representation of the layers in the model.
    """
    input_to_layer = {}
    model_name_to_output = {}
    g = GraphDef()

    # Sequential model layers do not have a field "inbound_nodes" but
    # instead are defined implicitly via order of layers.
    prev_node_name = None
    for (name_scope, layer) in _walk_layers(keras_layer):
        if _is_model(layer):
            (input_to_layer, model_name_to_output, prev_node_name) = _update_dicts(
                name_scope, layer, input_to_layer, model_name_to_output,
                prev_node_name)
            continue

        layer_config = layer.get('config')
        node_name = _scoped_name(name_scope, layer_config.get('name'))

        node_def = g.node.add()
        node_def.name = node_name

        if layer.get('class_name') is not None:
            keras_cls_name = layer.get('class_name').encode('ascii')
            node_def.attr['keras_class'].s = keras_cls_name

        if layer_config.get('dtype') is not None:
            tf_dtype = dtypes.as_dtype(layer_config.get('dtype'))
            node_def.attr['dtype'].type = tf_dtype.as_datatype_enum

        if layer.get('inbound_nodes') is not None:
            for maybe_inbound_node in layer.get('inbound_nodes'):
                inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
                for [name, size, index, _] in inbound_nodes:
                    inbound_name = _scoped_name(name_scope, name)
                    # An input to a layer can be output from a model. In that case, the name
                    # of inbound_nodes to a layer is a name of a model. Remap the name of the
                    # model to output layer of the model. Also, since there can be multiple
                    # outputs in a model, make sure we pick the right output_layer from the model.
                    inbound_node_names = model_name_to_output.get(
                        inbound_name, [inbound_name])
                    node_def.input.append(inbound_node_names[index])
        elif prev_node_name is not None:
            node_def.input.append(prev_node_name)

        if node_name in input_to_layer:
            node_def.input.append(input_to_layer.get(node_name))

        prev_node_name = node_def.name

    return g
def add_graph(self, model, *args, **kargs):
    visitor = GraphVisitor(model, *args, **kargs)
    stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    graph = GraphDef(node=visitor._graph, versions=VersionDef(producer=22))
    self._get_file_writer().add_graph((graph, stepstats))
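# Hedged sketch (not from the source): the proto objects that add_graph above
# assembles, built standalone with an empty node list so the snippet runs without
# a model. A SummaryWriter-like object would then log the (graph_def, stepstats) pair.
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import DeviceStepStats, StepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef

stepstats = RunMetadata(
    step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
)
graph_def = GraphDef(node=[], versions=VersionDef(producer=22))
# writer._get_file_writer().add_graph((graph_def, stepstats))  # as in add_graph above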
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param input: user defined input data for running model and calculating stats,
        alternative with inp_dict, used when the model has only one input.
    :param inp_dict: input dict for running model and calculating stats, alternative
        with input, used when the model has more than one input. When both input and
        inp_dict are None, a random input will be used.
    :param cal_params: whether to calculate and record params size.
    :param cal_flops: whether to calculate and record op flops.
    :param cal_activations: whether to calculate and record op activations.
    :param logging_to_stdout: whether to print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert (
                len(inp_vars) == 1
            ), "The model has more than one input; please provide inp_dict."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)
        graph._compile()

    def process_name(name):
        # nodes that start with a point or contain a float const will lead to a display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and does not have an "
                    "'output_idx' attr.".format(node)
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d)
                                    for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                if log_path and hasattr(flops_stats, "flops_num"):
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    if cal_flops and cal_params:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
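# Hedged usage sketch (not from the source): visualizing a dumped MegEngine model
# with the visualize() function above. The file paths are hypothetical placeholders;
# assumes megengine, tensorboard and tensorboardX are installed.
import numpy as np

total, details = visualize(
    model_path="./dumped_model.mge",  # hypothetical path to a dumped model
    log_path="./tb_logs",             # directory where the TensorBoard log is written
    input=np.random.rand(1, 3, 224, 224).astype(np.float32),
    logging_to_stdout=False,
)
print(total.param_size, total.flops, total.act_size)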
def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while processing.
        operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
            Defaults to 'ONNX' format because it outputs the most visually
            understandable format.
        omit_useless_nodes (boolean): Whether to remove nodes from the graph.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurred, no graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's an onnx problem...")
            try:
                import tempfile
                torch.onnx.export(
                    model, args, tempfile.TemporaryFile(), verbose=True)
            except RuntimeError:
                print("Your model cannot be exported by onnx, please report to the onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about those
        # detailed nodes' information.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimizing the trace might fail (due to bad scopes in some cases
        # we've seen) and we don't want graph visualization to fail in this
        # case, so we log the warning and display the non-optimized graph.
        logging.warning(e)

    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have
    # actually run on GPU. Note this is what is shown in TensorBoard and has no
    # bearing on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)

    enable_receptive_field()

    graph = Network.load(model_path)

    def process_name(name):
        # nodes that start with a point or contain a float const will lead to a display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and does not have an "
                    "'output_idx' attr.".format(node))
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(list=AttrValue.ListValue(shape=[
                    TensorShapeProto(dim=[
                        TensorShapeProto.Dim(size=d) for d in node_oup.shape
                    ])
                ])),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }
        flops_stats = get_op_stats(node, node.inputs, node.outputs)
        if flops_stats is not None:
            # add op flops attr
            if log_path and hasattr(flops_stats, "flops_num"):
                attr["flops"] = AttrValue(
                    s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8"))
            flops_stats["name"] = node.name
            flops_stats["class_name"] = node.type
            flops_list.append(flops_stats)

        if node.type == "ImmutableTensor":
            param_stats = get_param_stats(node.numpy())
            # add tensor size attr
            if log_path:
                attr["size"] = AttrValue(
                    s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8"))
            param_stats["name"] = node.name
            params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                ))

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }
    total_flops, total_param_dims, total_param_size = 0, 0, 0
    if log_params:
        total_param_dims, total_param_size = print_param_stats(
            params_list, bar_length_max)
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_flops:
        total_flops = print_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size)

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(step_stats=StepStats(
            dev_stats=[DeviceStepStats(device=device)]))
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    # FIXME: remove this after resolving "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops