示例#1
0
def onnx2dotnb(model_onnx, width="100%", orientation="LR"):
    """
    Draws an ONNX model as a :epkg:`DOT` graph and wraps the result
    into :epkg:`RenderJsDot` so that it can be displayed in a notebook.
    See :ref:`onnxsklearnconsortiumrst`.

    :param model_onnx: ONNX model, its ``graph`` attribute is drawn
    :param width: width of the rendered area, forwarded to
        :epkg:`RenderJsDot`
    :param orientation: value given to the ``rankdir`` argument of
        ``GetPydotGraph`` (for example ``"LR"`` or ``"TB"``)
    :return: a :epkg:`RenderJsDot` instance
    """
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    onnx_graph = model_onnx.graph
    # Every operator is drawn as a filled yellow node labelled with its
    # docstring.
    producer = GetOpNodeProducer("docstring", color="yellow",
                                 fillcolor="yellow", style="filled")
    drawing = GetPydotGraph(onnx_graph, name=onnx_graph.name,
                            rankdir=orientation, node_producer=producer)
    dot_source = drawing.to_string()
    return RenderJsDot(dot_source, width=width)
示例#2
0
    def to_dot(
            self,
            recursive=False,
            prefix='',  # pylint: disable=R0914
            add_rt_shapes=False,
            use_onnx=False,
            **params):
        """
        Produces a :epkg:`DOT` language string for the graph.

        :param params: additional params to draw the graph; entries whose
            value is None are ignored, the others override the default
            options listed below (only ``orientation``, ``pad``,
            ``nodesep``, ``ranksep`` and ``size`` are written into the
            graph header); ``rankdir`` is read when *use_onnx* is True
        :param recursive: also show subgraphs inside operator like
            @see cl Scan (``Scan``, ``Loop`` and ``If`` are expanded)
        :param prefix: prefix for every node name
        :param add_rt_shapes: adds shapes inferred from the python runtime
        :param use_onnx: use :epkg:`onnx` dot format instead of this one
        :return: string
        :raises RuntimeError: if *add_rt_shapes* is True and the runtime
            holds no shape information

        Nodes without a name receive one built from their operator type
        (``Add``, ``Add1``, ``Add2``, ...); this name is also written into
        the node itself so that edges can refer to it. Empty (optional)
        inputs and outputs are not drawn.

        Default options for the graph are:

        ::

            options = {
                'orientation': 'portrait',
                'ranksep': '0.25',
                'nodesep': '0.05',
                'width': '0.5',
                'height': '0.1',
                'size': '7',
            }

        One example:

        .. exref::
            :title: Convert ONNX into DOT

            An example on how to convert an :epkg:`ONNX`
            graph into :epkg:`DOT`.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from skl2onnx.algebra.onnx_ops import OnnxLinearRegressor
                from skl2onnx.common.data_types import FloatTensorType
                from mlprodict.onnxrt import OnnxInference

                pars = dict(coefficients=numpy.array([1., 2.]),
                            intercepts=numpy.array([1.]),
                            post_transform='NONE')
                onx = OnnxLinearRegressor('X', output_names=['Y'], **pars)
                model_def = onx.to_onnx({'X': pars['coefficients'].astype(numpy.float32)},
                                        outputs=[('Y', FloatTensorType([1]))],
                                        target_opset=12)
                oinf = OnnxInference(model_def)
                print(oinf.to_dot())

            See an example of representation in notebook
            :ref:`onnxvisualizationrst`.
        """
        clean_label_reg1 = re.compile("\\\\x\\{[0-9A-F]{1,6}\\}")
        clean_label_reg2 = re.compile("\\\\p\\{[0-9P]{1,6}\\}")

        def dot_name(text):
            # Replaces characters ('/', ':', '.') that are not allowed in
            # unquoted DOT identifiers.
            return text.replace("/", "_").replace(":", "__").replace(".", "_")

        def dot_label(text):
            # Replaces sequences such as \x{...} or \p{...} by '_'.
            for reg in [clean_label_reg1, clean_label_reg2]:
                fall = reg.findall(text)
                for f in fall:
                    text = text.replace(f, "_")  # pragma: no cover
            return text

        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '7',
        }
        options.update({k: v for k, v in params.items() if v is not None})

        if use_onnx:
            from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

            pydot_graph = GetPydotGraph(self.oinf.obj.graph,
                                        name=self.oinf.obj.graph.name,
                                        rankdir=params.get('rankdir', "TB"),
                                        node_producer=GetOpNodeProducer(
                                            "docstring",
                                            fillcolor="orange",
                                            style="filled",
                                            shape="box"))
            return pydot_graph.to_string()

        inter_vars = {}
        exp = ["digraph{"]
        # A tuple (not a set) keeps the header order stable: iteration order
        # of a set of strings depends on hash randomization and changes
        # from one process to another.
        for opt in ('orientation', 'pad', 'nodesep', 'ranksep', 'size'):
            if opt in options:
                exp.append("  {}={};".format(opt, options[opt]))
        fontsize = 10

        shapes = {}
        if add_rt_shapes:
            if not hasattr(self.oinf, 'shapes_'):
                raise RuntimeError(  # pragma: no cover
                    "No information on shapes, check the runtime '{}'.".format(
                        self.oinf.runtime))
            for name, shape in self.oinf.shapes_.items():
                va = shape.evaluate().to_string()
                shapes[name] = va
                if name in self.oinf.inplaces_:
                    shapes[name] += "\\ninplace"

        # inputs
        exp.append("")
        for obj in self.oinf.obj.graph.input:
            dobj = _var_as_dict(obj)
            sh = shapes.get(dobj['name'], '')
            if sh:
                sh = "\\nshape={}".format(sh)
            exp.append(
                '  {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'
                .format(dot_name(dobj['name']), _type_to_string(dobj['type']),
                        fontsize, prefix, dot_label(sh)))
            inter_vars[obj.name] = obj

        # outputs
        exp.append("")
        for obj in self.oinf.obj.graph.output:
            dobj = _var_as_dict(obj)
            sh = shapes.get(dobj['name'], '')
            if sh:
                sh = "\\nshape={}".format(sh)
            exp.append(
                '  {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'
                .format(dot_name(dobj['name']), _type_to_string(dobj['type']),
                        fontsize, prefix, dot_label(sh)))
            inter_vars[obj.name] = obj

        # initializer
        exp.append("")
        for obj in self.oinf.obj.graph.initializer:
            dobj = _var_as_dict(obj)
            val = dobj['value']
            st = str(val)
            # Only initializers with 9 elements or more get truncated.
            if val.flatten().shape[0] >= 9 and len(st) > 50:
                st = st[:50] + '...'
            st = st.replace('\n', '\\n')
            exp.append(
                '  {5}{0} [shape=box label="{0}\\n{1}({2})\\n{3}" fontsize={4}];'
                .format(dot_name(dobj['name']), val.dtype, val.shape,
                        dot_label(st), fontsize, prefix))
            inter_vars[obj.name] = obj

        # nodes
        fill_names = {}
        static_inputs = [n.name for n in self.oinf.obj.graph.input]
        static_inputs.extend(n.name for n in self.oinf.obj.graph.initializer)
        for node in self.oinf.obj.graph.node:
            exp.append("")
            for out in node.output:
                if len(out) > 0 and out not in inter_vars:
                    inter_vars[out] = out
                    sh = shapes.get(out, '')
                    if sh:
                        sh = "\\nshape={}".format(sh)
                    exp.append(
                        '  {2}{0} [shape=box label="{0}{3}" fontsize={1}];'.
                        format(dot_name(out), fontsize, dot_name(prefix),
                               dot_label(sh)))
                static_inputs.append(out)

            dobj = _var_as_dict(node)
            if dobj['name'].strip() == '':  # pragma: no cover
                # Unnamed node: picks op_type, op_type1, op_type2, ... (the
                # first one not already given to another unnamed node) and
                # stores it in the node so that the edges below refer to it.
                name = node.op_type
                iname = 1
                while name in fill_names:
                    name = "%s%d" % (node.op_type, iname)
                    iname += 1
                dobj['name'] = name
                node.name = name
                fill_names[name] = node

            atts = []
            if 'atts' in dobj:
                for k, v in sorted(dobj['atts'].items()):
                    val = None
                    if 'value' in v:
                        val = str(v['value']).replace("\n",
                                                      "\\n").replace('"', "'")
                        sl = max(30 - len(k), 10)
                        if len(val) > sl:
                            val = val[:sl] + "..."
                    if val is not None:
                        atts.append('{}={}'.format(k, val))
            satts = "" if len(atts) == 0 else ("\\n" + "\\n".join(atts))

            connects = []
            if recursive and node.op_type in {'Scan', 'Loop', 'If'}:
                fields = (['then_branch', 'else_branch']
                          if node.op_type == 'If' else ['body'])
                for field in fields:
                    if field not in dobj['atts']:
                        continue  # pragma: no cover

                    # creates the subgraph
                    body = dobj['atts'][field]['value']
                    oinf = self.oinf.__class__(body,
                                               runtime=self.oinf.runtime,
                                               skip_run=self.oinf.skip_run,
                                               static_inputs=static_inputs)
                    subprefix = prefix + "B_"
                    subdot = oinf.to_dot(recursive=recursive,
                                         prefix=subprefix,
                                         add_rt_shapes=add_rt_shapes)
                    # Drops the subgraph header ("digraph{" and options):
                    # keeps everything from the first node declaration,
                    # including the closing '}' which closes the cluster.
                    lines = subdot.split("\n")
                    start = 0
                    for i, line in enumerate(lines):
                        if '[' in line:
                            start = i
                            break
                    subgraph = "\n".join(lines[start:])

                    # connecting the subgraph
                    cluster = "cluster_{}{}_{}".format(node.op_type, id(node),
                                                       id(field))
                    exp.append("  subgraph {} {{".format(cluster))
                    exp.append('    label="{0}\\n({1}){2}";'.format(
                        dobj['op_type'], dot_name(dobj['name']), satts))
                    exp.append('    fontsize={0};'.format(fontsize))
                    exp.append('    color=black;')
                    exp.append('\n'.join(
                        map(lambda s: '  ' + s, subgraph.split('\n'))))

                    # An empty body has no first node to point to.
                    if len(body.node) > 0:
                        node0 = body.node[0]
                        connects.append(
                            ("{}{}".format(dot_name(subprefix),
                                           dot_name(node0.name)), cluster))

                    for inp1, inp2 in zip(node.input, body.input):
                        if len(inp1) == 0:
                            # Empty (optional) input, nothing to connect.
                            continue
                        exp.append("  {0}{1} -> {2}{3};".format(
                            dot_name(prefix), dot_name(inp1),
                            dot_name(subprefix), dot_name(inp2.name)))
                    for out1, out2 in zip(body.output, node.output):
                        if len(out2) == 0:
                            # Empty output, it cannot be used.
                            continue
                        exp.append("  {0}{1} -> {2}{3};".format(
                            dot_name(subprefix), dot_name(out1.name),
                            dot_name(prefix), dot_name(out2)))
            else:
                exp.append(
                    '  {4}{1} [shape=box style="filled,rounded" color=orange label="{0}\\n({1}){2}" fontsize={3}];'
                    .format(dobj['op_type'], dot_name(dobj['name']), satts,
                            fontsize, dot_name(prefix)))

            for name, cluster in connects:
                exp.append("  {0}{1} -> {2} [lhead={3}];".format(
                    dot_name(prefix), dot_name(node.name), name, cluster))

            for inp in node.input:
                if len(inp) == 0:
                    # Empty (optional) input: drawing it would produce an
                    # edge without a source ("  -> node;"), invalid DOT.
                    continue
                exp.append("  {0}{1} -> {0}{2};".format(
                    dot_name(prefix), dot_name(inp), dot_name(node.name)))
            for out in node.output:
                if len(out) == 0:
                    # Empty output, it cannot be used.
                    continue
                exp.append("  {0}{1} -> {0}{2};".format(
                    dot_name(prefix), dot_name(node.name), dot_name(out)))

        exp.append('}')
        return "\n".join(exp)