def onnx2dotnb(model_onnx, width="100%", orientation="LR"):
    """
    Converts an ONNX graph into dot then into :epkg:`RenderJsDot`.
    See :ref:`onnxsklearnconsortiumrst`.

    :param model_onnx: ONNX model whose ``graph`` attribute is drawn
    :param width: forwarded as ``width`` to :epkg:`RenderJsDot`
    :param orientation: forwarded as ``rankdir`` to onnx's
        ``GetPydotGraph`` (``"LR"`` by default)
    :return: a :epkg:`RenderJsDot` instance built from the DOT source
    """
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

    graph = model_onnx.graph
    # Operator nodes are drawn filled in yellow and show their docstring.
    producer = GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled")
    dot_source = GetPydotGraph(
        graph, name=graph.name, rankdir=orientation,
        node_producer=producer).to_string()
    return RenderJsDot(dot_source, width=width)
def to_dot(self, recursive=False, prefix='',  # pylint: disable=R0914
           add_rt_shapes=False, use_onnx=False, **params):
    """
    Produces a :epkg:`DOT` language string for the graph.

    :param params: additional params to draw the graph
    :param recursive: also show subgraphs inside operator like
        @see cl Scan
    :param prefix: prefix for every node name
    :param add_rt_shapes: adds shapes infered from the python runtime
    :param use_onnx: use :epkg:`onnx` dot format instead of this one
    :return: string
    :raises RuntimeError: if *add_rt_shapes* is True but the runtime
        did not compute any shape (no attribute ``shapes_``)

    Default options for the graph are:

    ::

        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '7',
        }

    Only the graph attributes ``orientation``, ``pad``, ``nodesep``,
    ``ranksep`` and ``size`` are written into the DOT output, always in
    that order, so the output is identical from one run to another.
    Nodes without a name receive one derived from their operator type
    (``Add``, ``Add1``, ``Add2``, ...) and the ONNX node is renamed
    in place. Empty input or output names (omitted optional values)
    produce no edge.

    One example:

    .. exref::
        :title: Convert ONNX into DOT

        An example on how to convert an :epkg:`ONNX`
        graph into :epkg:`DOT`.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import numpy
            from skl2onnx.algebra.onnx_ops import OnnxLinearRegressor
            from skl2onnx.common.data_types import FloatTensorType
            from mlprodict.onnxrt import OnnxInference

            pars = dict(coefficients=numpy.array([1., 2.]),
                        intercepts=numpy.array([1.]),
                        post_transform='NONE')
            onx = OnnxLinearRegressor('X', output_names=['Y'], **pars)
            model_def = onx.to_onnx(
                {'X': pars['coefficients'].astype(numpy.float32)},
                outputs=[('Y', FloatTensorType([1]))],
                target_opset=12)
            oinf = OnnxInference(model_def)
            print(oinf.to_dot())

    See an example of representation in notebook
    :ref:`onnxvisualizationrst`.
    """
    clean_label_reg1 = re.compile(r"\\x\{[0-9A-F]{1,6}\}")
    clean_label_reg2 = re.compile(r"\\p\{[0-9P]{1,6}\}")

    def dot_name(text):
        # Maps a result/node name to a DOT identifier.
        return text.replace("/", "_").replace(":", "__").replace(".", "_")

    def dot_label(text):
        # Replaces escaped sequences such as \x{...} or \p{...} by '_'.
        for reg in (clean_label_reg1, clean_label_reg2):
            text = reg.sub("_", text)
        return text

    options = {
        'orientation': 'portrait',
        'ranksep': '0.25',
        'nodesep': '0.05',
        'width': '0.5',
        'height': '0.1',
        'size': '7',
    }
    options.update({k: v for k, v in params.items() if v is not None})

    if use_onnx:
        from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

        pydot_graph = GetPydotGraph(
            self.oinf.obj.graph, name=self.oinf.obj.graph.name,
            rankdir=params.get('rankdir', "TB"),
            node_producer=GetOpNodeProducer(
                "docstring", fillcolor="orange", style="filled",
                shape="box"))
        return pydot_graph.to_string()

    graph = self.oinf.obj.graph
    fontsize = 10
    # Sanitized prefix used for every node declaration AND every edge so
    # that both always refer to the same DOT node.
    dprefix = dot_name(prefix)

    exp = ["digraph{"]
    # A tuple, not a set: the iteration order of a set of strings depends
    # on hash randomization and would make the output vary between runs.
    for opt in ('orientation', 'pad', 'nodesep', 'ranksep', 'size'):
        if opt in options:
            exp.append(" {}={};".format(opt, options[opt]))

    shapes = {}
    if add_rt_shapes:
        if not hasattr(self.oinf, 'shapes_'):
            raise RuntimeError(  # pragma: no cover
                "No information on shapes, check the runtime '{}'.".format(
                    self.oinf.runtime))
        for name, shape in self.oinf.shapes_.items():
            va = shape.evaluate().to_string()
            shapes[name] = va
            if name in self.oinf.inplaces_:
                shapes[name] += "\\ninplace"

    def shape_suffix(name):
        # Label suffix showing the runtime shape of a result when known.
        sh = shapes.get(name, '')
        return "\\nshape={}".format(sh) if sh else ''

    inter_vars = {}

    # inputs
    exp.append("")
    for obj in graph.input:
        dobj = _var_as_dict(obj)
        exp.append(
            ' {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'
            .format(dot_name(dobj['name']), _type_to_string(dobj['type']),
                    fontsize, dprefix,
                    dot_label(shape_suffix(dobj['name']))))
        inter_vars[obj.name] = obj

    # outputs
    exp.append("")
    for obj in graph.output:
        dobj = _var_as_dict(obj)
        exp.append(
            ' {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'
            .format(dot_name(dobj['name']), _type_to_string(dobj['type']),
                    fontsize, dprefix,
                    dot_label(shape_suffix(dobj['name']))))
        inter_vars[obj.name] = obj

    # initializer
    exp.append("")
    for obj in graph.initializer:
        dobj = _var_as_dict(obj)
        val = dobj['value']
        st = str(val)
        # Only tensors with 9 elements or more get a truncated label.
        if val.size >= 9 and len(st) > 50:
            st = st[:50] + '...'
        st = st.replace('\n', '\\n')
        exp.append(
            ' {5}{0} [shape=box label="{0}\\n{1}({2})\\n{3}" fontsize={4}];'
            .format(dot_name(dobj['name']), val.dtype, val.shape,
                    dot_label(st), fontsize, dprefix))
        inter_vars[obj.name] = obj

    # nodes
    # Every name already used by a node: generated names must not collide
    # with explicit ones, otherwise two nodes would be merged in the drawing.
    taken_names = {n.name for n in graph.node if n.name.strip() != ''}
    static_inputs = [n.name for n in graph.input]
    static_inputs.extend(n.name for n in graph.initializer)
    for node in graph.node:
        exp.append("")
        for out in node.output:
            if len(out) > 0 and out not in inter_vars:
                inter_vars[out] = out
                exp.append(
                    ' {2}{0} [shape=box label="{0}{3}" fontsize={1}];'.format(
                        dot_name(out), fontsize, dprefix,
                        dot_label(shape_suffix(out))))
            static_inputs.append(out)

        dobj = _var_as_dict(node)
        if dobj['name'].strip() == '':  # pragma: no cover
            # Candidates are op_type, op_type1, op_type2, ... (the suffix
            # is not appended to the previous candidate).
            name = node.op_type
            iname = 1
            while name in taken_names:
                name = "%s%d" % (node.op_type, iname)
                iname += 1
            dobj['name'] = name
            node.name = name
            taken_names.add(name)

        node_atts = dobj.get('atts', {})
        atts = []
        for k, v in sorted(node_atts.items()):
            if 'value' not in v:
                continue
            val = str(v['value']).replace("\n", "\\n").replace('"', "'")
            # Longer attribute names leave less room for the value.
            sl = max(30 - len(k), 10)
            if len(val) > sl:
                val = val[:sl] + "..."
            atts.append('{}={}'.format(k, val))
        satts = "" if len(atts) == 0 else ("\\n" + "\\n".join(atts))

        connects = []
        if recursive and node.op_type in {'Scan', 'Loop', 'If'}:
            fields = (['then_branch', 'else_branch']
                      if node.op_type == 'If' else ['body'])
            for field in fields:
                if field not in node_atts:
                    continue  # pragma: no cover
                # creates the subgraph
                body = node_atts[field]['value']
                oinf = self.oinf.__class__(
                    body, runtime=self.oinf.runtime,
                    skip_run=self.oinf.skip_run,
                    static_inputs=static_inputs)
                subprefix = prefix + "B_"
                dsubprefix = dot_name(subprefix)
                subdot = oinf.to_dot(recursive=recursive, prefix=subprefix,
                                     add_rt_shapes=add_rt_shapes)
                lines = subdot.split("\n")
                # Skips "digraph{" and the graph options: the body starts
                # at the first declaration. Its closing '}' also closes the
                # cluster opened below.
                start = next(
                    (i for i, line in enumerate(lines) if '[' in line), 0)
                subgraph = "\n".join(lines[start:])

                # connecting the subgraph
                cluster = "cluster_{}{}_{}".format(
                    node.op_type, id(node), id(field))
                exp.append(" subgraph {} {{".format(cluster))
                exp.append(' label="{0}\\n({1}){2}";'.format(
                    dobj['op_type'], dot_name(dobj['name']), satts))
                exp.append(' fontsize={0};'.format(fontsize))
                exp.append(' color=black;')
                exp.append('\n'.join(' ' + s for s in subgraph.split('\n')))

                if len(body.node) > 0:
                    connects.append(
                        ("{}{}".format(dsubprefix,
                                       dot_name(body.node[0].name)),
                         cluster))
                for inp1, inp2 in zip(node.input, body.input):
                    if len(inp1) == 0:
                        # Omitted optional input, nothing to connect.
                        continue
                    exp.append(" {0}{1} -> {2}{3};".format(
                        dprefix, dot_name(inp1), dsubprefix,
                        dot_name(inp2.name)))
                for out1, out2 in zip(body.output, node.output):
                    if len(out2) == 0:
                        # Empty output, it cannot be used.
                        continue
                    exp.append(" {0}{1} -> {2}{3};".format(
                        dsubprefix, dot_name(out1.name), dprefix,
                        dot_name(out2)))
        else:
            exp.append(
                ' {4}{1} [shape=box style="filled,rounded" color=orange label="{0}\\n({1}){2}" fontsize={3}];'
                .format(dobj['op_type'], dot_name(dobj['name']), satts,
                        fontsize, dprefix))

        for name, cluster in connects:
            exp.append(" {0}{1} -> {2} [lhead={3}];".format(
                dprefix, dot_name(node.name), name, cluster))

        for inp in node.input:
            if len(inp) == 0:
                # Omitted optional input: an edge would have an empty end
                # and produce invalid DOT.
                continue
            exp.append(" {0}{1} -> {0}{2};".format(
                dprefix, dot_name(inp), dot_name(node.name)))
        for out in node.output:
            if len(out) == 0:
                # Empty output, it cannot be used.
                continue
            exp.append(" {0}{1} -> {0}{2};".format(
                dprefix, dot_name(node.name), dot_name(out)))

    exp.append('}')
    return "\n".join(exp)