Exemple #1
0
def add_edge(dot, src, dst):
    """Add a ``src -> dst`` edge to the pydot graph ``dot`` unless one is already there.

    ``dot.get_edge(src, dst)`` serves as the existence check: any truthy
    result means the edge is present and the graph is left unchanged.
    """
    if dot.get_edge(src, dst):
        return  # edge already present: nothing to add
    dot.add_edge(pydot.Edge(src, dst))
Exemple #2
0
def plot(root, filename=None):
    '''
    Walks through every node of the graph starting at ``root``,
    creates a network graph, and returns a network description. If ``filename`` is
    specified, it outputs a DOT, PNG, PDF, or SVG file depending on the file name's suffix.

    Requirements:

     * for DOT output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`__
     * for PNG, PDF, and SVG output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`__
       and `graphviz <http://graphviz.org>`__ (GraphViz executable has to be in the system's PATH).

    Args:
        root (graph node): the node to start the journey from; the walk
         begins at its ``root_function``
        filename (`str`, default None): file with extension '.dot', '.png', '.pdf', or '.svg'
         to denote what format should be written. If `None` then nothing
         will be plotted, and the returned string can be used to debug the graph.

    Returns:
        `str` describing the graph: one ``op(inputs) -> output;`` line per
        function output, in reverse visiting order

    Raises:
        ValueError: if ``filename`` has an unsupported extension
        ImportError: if ``filename`` is given but ``pydot_ng`` cannot be imported
    '''
    from collections import deque

    if filename:
        suffix = os.path.splitext(filename)[1].lower()
        if suffix not in ('.svg', '.pdf', '.png', '.dot'):
            raise ValueError('only file extensions ".svg", ".pdf", ".png", and ".dot" are supported')
        try:
            import pydot_ng as pydot
        except ImportError as err:
            raise ImportError("Unable to import pydot_ng, which is required to output SVG, PDF, PNG, and DOT format.") from err

        # initialize a dot object to store vertices and edges
        dot_object = pydot.Dot(graph_name="network_graph", rankdir='TB')
        dot_object.set_node_defaults(shape='rectangle', fixedsize='false',
                                     style='filled',
                                     fillcolor='lightgray',
                                     height=.85, width=.85, fontsize=12)
        dot_object.set_edge_defaults(fontsize=10)
    else:
        suffix = None

    from cntk import cntk_py

    # lines of the textual model description (reversed before returning)
    model = []

    root = root.root_function
    root_uid = root.uid
    # Work list consumed from the front (inputs are visited before queued work);
    # a deque makes both popping and prepending O(1).
    stack = deque([root])
    # store uids instead of node objects, as objects give duplicate entries
    # for nodes with multiple outputs
    visited = set()

    primitive_op_map = {
        'Plus': '+',
        'Minus': '-',
        'ElementTimes': '*',
        'Times': '@',
    }
    function_nodes = {}  # [uid] -> dot node of a Function
    variable_nodes = {}  # [uid] -> dot node of a non-output Variable

    def shape_desc(node):
        '''Return a compact shape string, prefixed with the dynamic axes if any.'''
        # the '#' indicates the batch axis, while * indicate dynamic axes (which can be sequences)
        dyn_axes = node.dynamic_axes
        dyn = '[#' + ',*' * (len(dyn_axes) - 1) + ']' if len(dyn_axes) > 0 else ''
        return dyn + str(node.shape)

    def lazy_create_node(node):
        '''Return the dot node of Function ``node``, creating and adding it on first use.'''
        if node.uid in function_nodes:  # dot node already exists
            return function_nodes[node.uid]
        penwidth = 4 if node.op_name != 'Pass' and node.op_name != 'ParameterOrder' else 1
        if node.is_primitive and not node.is_block and len(node.outputs) == 1 and node.output.name == node.name:
            # the node name is redundant, so only the (possibly symbolic) op is shown
            op_name = primitive_op_map.get(node.op_name, node.op_name)
            render_as_primitive = len(op_name) <= 4
            size = 0.4 if render_as_primitive else 0.6
            cur_node = pydot.Node(node.uid, label='"' + op_name + '"',
                                  shape='ellipse' if render_as_primitive else 'box',
                                  fixedsize='true' if render_as_primitive else 'false', height=size, width=size,
                                  fontsize=20 if render_as_primitive and len(op_name) == 1 else 12,
                                  penwidth=penwidth)
            # TODO: Would be cool, if the user could pass a dictionary with overrides. But maybe for a later version.
        else:
            f_name = '\n' + node.name + '()' if node.name else ''
            cur_node = pydot.Node(node.uid, label='"' + node.op_name + f_name + '"',
                                  fixedsize='true', height=1, width=1.3,
                                  penwidth=penwidth)
        dot_object.add_node(cur_node)
        function_nodes[node.uid] = cur_node
        return cur_node

    def lazy_create_variable_node(var):
        '''Return the dot node of a non-output Variable, creating and adding it on first use.'''
        if var.uid in variable_nodes:
            return variable_nodes[var.uid]
        name = 'Parameter' if var.is_parameter else 'Constant' if var.is_constant else 'Input' if var.is_input else 'Placeholder'
        if var.name:
            if name == 'Parameter':  # don't say 'Parameter' for named parameters, it's already indicated by being a box
                name = var.name
            else:
                name = name + '\n' + var.name
        name += '\n' + shape_desc(var)
        if var.is_input or var.is_placeholder:  # graph inputs are eggs (since dot has no oval)
            var_node = pydot.Node(var.uid, shape='egg', label=name, fixedsize='true', height=1, width=1.3, penwidth=4)  # wish it had an oval
        elif not var.name and var.is_constant and (var.shape == () or var.shape == (1,)):  # unnamed scalar constants are just shown as values
            var_node = pydot.Node(var.uid, shape='box', label=str(var.as_constant().value), color='white', fillcolor='white', height=0.3, width=0.4)
        else:  # parameters and constants are boxes
            var_node = pydot.Node(var.uid, shape='box', label=name, height=0.6, width=1)
        dot_object.add_node(var_node)
        variable_nodes[var.uid] = var_node
        return var_node

    while stack:
        node = stack.popleft()

        if node.uid in visited:
            continue

        # NOTE(review): AttributeError doubles as type dispatch here (Variables
        # have no ``root_function``), so any AttributeError raised while drawing
        # a Function is also routed to the Variable branch below.
        try:
            # Function node
            node = node.root_function

            inputs = list(node.inputs)
            # visit this function's inputs next, in their original order
            stack.extendleft(reversed(inputs))

            # add current Function node
            if filename:
                cur_node = lazy_create_node(node)

            # add node's inputs
            input_uids = []
            for inp in inputs:
                # Suppress Constants inside BlockFunctions, since those are really private to the BlockFunction.
                # Still show Parameters, so users know what parameters it learns, e.g. a layer.
                if node.is_block and isinstance(inp, cntk_py.Variable) and inp.is_constant:
                    continue

                input_uids.append(inp.uid)

                if filename:
                    if isinstance(inp, cntk_py.Variable) and not inp.is_output:
                        input_node = lazy_create_variable_node(inp)
                    else:  # output variables never get drawn except the final output
                        assert isinstance(inp, cntk_py.Variable)
                        # connect to where the output comes from directly, no need to draw it
                        input_node = lazy_create_node(inp.owner)
                    label = inp.name if inp.name else inp.uid  # the Output variables have no name if the function has none
                    label += '\n' + shape_desc(inp)
                    dot_object.add_edge(pydot.Edge(input_node, cur_node, label=label))

            # add node's outputs; joining avoids a trailing separator when the
            # last input was suppressed above
            line = '%s(%s) -> ' % (node.op_name, ', '.join(input_uids))
            for n in node.outputs:
                model.append(line + n.uid + ';\n')

            if filename and node.uid == root_uid:  # only final network outputs are drawn
                for output in node.outputs:
                    final_node = pydot.Node(output.uid, shape='egg', label=output.name + '\n' + shape_desc(output),
                                            fixedsize='true', height=1, width=1.3, penwidth=4)
                    dot_object.add_node(final_node)
                    dot_object.add_edge(pydot.Edge(cur_node, final_node, label=shape_desc(output)))

        except AttributeError:
            # OutputVariable node: continue with the function that produces it
            try:
                if node.is_output:
                    stack.appendleft(node.owner)
            except AttributeError:
                pass

        visited.add(node.uid)

    if filename:
        if suffix == '.svg':
            dot_object.write_svg(filename, prog='dot')
        elif suffix == '.pdf':
            dot_object.write_pdf(filename, prog='dot')
        elif suffix == '.png':
            dot_object.write_png(filename, prog='dot')
        else:
            dot_object.write_raw(filename)

    return "\n".join(reversed(model))
Exemple #3
0
def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 rankdir='TB'):
    """Convert a Keras model to dot format.

    # Arguments
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.

    # Returns
        A `pydot.Dot` instance representing the Keras model.
    """
    from ..layers.wrappers import Wrapper
    from ..models import Sequential

    _check_pydot()
    dot = pydot.Dot()
    for key, value in (('rankdir', rankdir), ('concentrate', True)):
        dot.set(key, value)
    dot.set_node_defaults(shape='record')

    # A Sequential model is drawn through its (lazily built) inner model.
    if isinstance(model, Sequential):
        if not model.built:
            model.build()
        model = model.model
    layers = model.layers

    def node_label(layer):
        """Build the record label for one layer node."""
        name = layer.name
        kind = layer.__class__.__name__
        if isinstance(layer, Wrapper):
            # Show the wrapped layer in parentheses after the wrapper.
            name = '{}({})'.format(name, layer.layer.name)
            kind = '{}({})'.format(kind, layer.layer.__class__.__name__)
        label = '{}: {}'.format(name, kind) if show_layer_names else kind
        if not show_shapes:
            return label

        try:
            output_text = str(layer.output_shape)
        except AttributeError:
            output_text = 'multiple'
        if hasattr(layer, 'input_shape'):
            input_text = str(layer.input_shape)
        elif hasattr(layer, 'input_shapes'):
            input_text = ', '.join(str(shape) for shape in layer.input_shapes)
        else:
            input_text = 'multiple'
        # Record layout: the label, then a two-row table of input/output shapes.
        return '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, input_text, output_text)

    # One node per layer, keyed by the layer's id().
    for layer in layers:
        dot.add_node(pydot.Node(str(id(layer)), label=node_label(layer)))

    # One edge per inbound connection that belongs to this model's container.
    for layer in layers:
        target_id = str(id(layer))
        for node_index, inbound_node in enumerate(layer.inbound_nodes):
            if layer.name + '_ib-' + str(node_index) not in model.container_nodes:
                continue
            for inbound_layer in inbound_node.inbound_layers:
                dot.add_edge(pydot.Edge(str(id(inbound_layer)), target_id))
    return dot
Exemple #4
0
def output_function_graph(node, dot_file_path=None, png_file_path=None):
    '''
    Walks through every node of the graph starting at ``node``,
    creates a network graph, and returns it as a string. If ``dot_file_path``
    or ``png_file_path`` is specified, the corresponding file is saved.

    Every node is processed at most once (tracked by ``uid``), so shared
    subgraphs are listed a single time and recurrent graphs terminate.

    Requirements:

     * for DOT output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`_
     * for PNG output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`_
       and `graphviz <http://graphviz.org>`_

    Args:
        node (graph node): the node to start the journey from
        dot_file_path (`str`, optional): DOT file path
        png_file_path (`str`, optional): PNG file path

    Returns:
        `str` containing one ``op(inputs) -> first_output`` line per function,
        in reverse visiting order (the string starts with a newline)

    Raises:
        ImportError: if a file path is given but ``pydot_ng`` cannot be imported
    '''

    dot = dot_file_path is not None
    png = png_file_path is not None

    if dot or png:
        try:
            import pydot_ng as pydot
        except ImportError as err:
            raise ImportError("PNG and DOT format requires pydot_ng package. Unable to import pydot_ng.") from err

        # initialize a dot object to store vertices and edges
        dot_object = pydot.Dot(graph_name="network_graph", rankdir='TB')
        dot_object.set_node_defaults(shape='rectangle', fixedsize='false',
                                     height=.85, width=.85, fontsize=12)
        dot_object.set_edge_defaults(fontsize=10)

    # pieces of the textual model description
    parts = []

    # walk every node of the graph iteratively (depth-first, LIFO)
    stack = [node]
    visited = set()  # uids of processed nodes; prevents repeats and infinite cycles

    while stack:
        node = stack.pop()

        if node.uid in visited:
            continue

        try:
            # Function node (Variables have no ``root_function``)
            node = node.root_function
            inputs = node.inputs
            stack.extend(inputs)

            # add current node
            parts.append(node.op_name + '(')
            if dot or png:
                cur_node = pydot.Node(node.op_name + ' ' + node.uid, label=node.op_name, shape='circle',
                                      fixedsize='true', height=1, width=1)
                dot_object.add_node(cur_node)

            # add node's inputs
            parts.append(', '.join(child.uid for child in inputs))
            if dot or png:
                for child in inputs:
                    child_node = pydot.Node(child.uid)
                    dot_object.add_node(child_node)
                    dot_object.add_edge(pydot.Edge(child_node, cur_node, label=str(child.shape)))

            # add node's (first) output
            output = node.outputs[0]
            parts.append(") -> " + output.uid + '\n')

            if dot or png:
                out_node = pydot.Node(output.uid)
                dot_object.add_node(out_node)
                dot_object.add_edge(pydot.Edge(cur_node, out_node, label=str(output.shape)))

        except AttributeError:
            # OutputVariable node: continue with the function that produces it
            try:
                if node.is_output:
                    stack.append(node.owner)
            except AttributeError:
                pass

        visited.add(node.uid)

    if png:
        dot_object.write_png(png_file_path, prog='dot')
    if dot:
        dot_object.write_raw(dot_file_path)

    # return lines in reversed order
    model = ''.join(parts)
    return "\n".join(model.split("\n")[::-1])
Exemple #5
0
# Enumerate sentences and process them one at a time.
for sentence in root.iterfind('./document/sentences/sentence'):
    sent_id = int(sentence.get('id'))

    # Collect (governor, dependent) pairs from the collapsed dependencies;
    # each end is an (idx, text) tuple.
    edges = []
    for dep in sentence.iterfind(
            './dependencies[@type="collapsed-dependencies"]/dep'):
        if dep.get('type') == 'punct':
            continue  # skip punctuation
        governor = dep.find('./governor')
        dependent = dep.find('./dependent')
        edges.append(((governor.get('idx'), governor.text),
                      (dependent.get('idx'), dependent.text)))

    # Draw the graph only when the sentence has at least one dependency.
    if not edges:
        continue
    graph = pydot.Dot(graph_type='digraph')
    for (gov_idx, gov_text), (dep_idx, dep_text) in edges:
        src_id = str(gov_idx)
        src_label = str(gov_text)
        dst_id = str(dep_idx)
        dst_label = str(dep_text)
        # add both end nodes, then the edge between them
        graph.add_node(pydot.Node(src_id, label=src_label))
        graph.add_node(pydot.Node(dst_id, label=dst_label))
        graph.add_edge(pydot.Edge(src_id, dst_id))
    graph.write_png('./file/{}.png'.format(sent_id))
Exemple #6
0
def pydotprint(
    fct,
    outfile=None,
    compact=True,
    format="png",
    with_ids=False,
    high_contrast=True,
    cond_highlight=None,
    colorCodes=None,
    max_label_size=70,
    scan_graphs=False,
    var_with_name_simple=False,
    print_output_file=True,
    return_image=False,
):
    """Print to a file the graph of a compiled theano function's ops. Supports
    all pydot output formats, including png and svg.

    :param fct: a compiled Theano function, a Variable, an Apply or
                a list of Variable.
    :param outfile: the output file where to put the graph.
    :param compact: if True, will remove intermediate var that don't have name.
    :param format: the file format of the output.
    :param with_ids: Print the toposort index of the node in the node name.
                     and an index number in the variable ellipse.
    :param high_contrast: if true, the color that describes the respective
            node is filled with its corresponding color, instead of coloring
            the border
    :param colorCodes: dictionary with names of ops as keys and colors as
            values
    :param cond_highlight: Highlights a lazy if by surrounding each of the 3
                possible categories of ops with a border. The categories
                are: ops that are on the left branch, ops that are on the
                right branch, ops that are on both branches
                As an alternative you can provide the node that represents
                the lazy if
    :param scan_graphs: if true it will plot the inner graph of each scan op
                in files with the same name as the name given for the main
                file to which the name of the scan op is concatenated and
                the index in the toposort of the scan.
                This index can be printed with the option with_ids.
    :param var_with_name_simple: If true and a variable has a name,
                we will print only the variable name.
                Otherwise, we concatenate the type to the var name.
    :param print_output_file: If True (and ``return_image`` is False),
                print the path of the written file.
    :param return_image: If True, it will create the image and return it.
        Useful to display the image in ipython notebook.

        .. code-block:: python

            import theano
            v = theano.tensor.vector()
            from IPython.display import SVG
            SVG(theano.printing.pydotprint(v*2, return_image=True,
                                           format='svg'))

    In the graph, ellipses are Apply Nodes (the execution of an op)
    and boxes are variables.  If variables have names they are used as
    text (if multiple vars have the same name, they will be merged in
    the graph).  Otherwise, if the variable is constant, we print its
    value and finally we print the type + a unique number to prevent
    multiple vars from being merged.  We print the op of the apply in
    the Apply box with a number that represents the toposort order of
    application of those Apply.  If an Apply has more than 1 input, we
    label each edge between an input and the Apply node with the
    input's index.

    Variable color code::
        - Cyan boxes are SharedVariables (inputs and/or outputs of the graph),
        - Green boxes are inputs variables to the graph,
        - Blue boxes are outputs variables of the graph,
        - Grey boxes are variables that are not outputs and are not used,

    Default apply node code::
        - Red ellipses are transfers from/to the gpu
        - Yellow are scan node
        - Brown are shape node
        - Magenta are IfElse node
        - Dark pink are elemwise node
        - Purple are subtensor
        - Orange are alloc node

    For edges, they are black by default. If a node returns a view
    of an input, we put the corresponding input edge in blue. If it
    returns a destroyed input, we put the corresponding edge in red.

    .. note::

        Since October 20th, 2014, this prints the inner function of all
        scans separately after the top level debugprint output.

    """
    from theano.scan.op import Scan

    if colorCodes is None:
        colorCodes = default_colorCodes

    if outfile is None:
        outfile = os.path.join(
            config.compiledir, "theano.pydotprint." + config.device + "." + format
        )

    if isinstance(fct, Function):
        profile = getattr(fct, "profile", None)
        fgraph = fct.maker.fgraph
        outputs = fgraph.outputs
        topo = fgraph.toposort()
    elif isinstance(fct, gof.FunctionGraph):
        profile = None
        outputs = fct.outputs
        topo = fct.toposort()
        fgraph = fct
    else:
        if isinstance(fct, gof.Variable):
            fct = [fct]
        elif isinstance(fct, gof.Apply):
            fct = fct.outputs
        assert isinstance(fct, (list, tuple))
        assert all(isinstance(v, gof.Variable) for v in fct)
        fct = gof.FunctionGraph(inputs=list(gof.graph.inputs(fct)), outputs=fct)
        profile = None
        outputs = fct.outputs
        topo = fct.toposort()
        fgraph = fct
    if not pydot_imported:
        raise RuntimeError(
            "Failed to import pydot. You must install graphviz"
            " and either pydot or pydot-ng for "
            "`pydotprint` to work.",
            pydot_imported_msg,
        )

    g = pd.Dot()

    if cond_highlight is not None:
        c1 = pd.Cluster("Left")
        c2 = pd.Cluster("Right")
        c3 = pd.Cluster("Middle")
        cond = None
        for node in topo:
            # cond_highlight may be the IfElse op's name or the Apply node itself
            if node.op.__class__.__name__ == "IfElse" and (
                node is cond_highlight or node.op.name == cond_highlight
            ):
                cond = node
        if cond is None:
            _logger.warning(
                "pydotprint: cond_highlight is set but there is no"
                " IfElse node in the graph"
            )
            cond_highlight = None

    if cond_highlight is not None:

        def branch_apply_nodes(var):
            """Return the set of Apply nodes that ``var`` (transitively) depends on."""
            found = set()
            pending = [var]
            while pending:
                owner = pending.pop().owner
                if owner is None or owner in found:
                    continue
                found.add(owner)
                pending.extend(owner.inputs)
            return found

        # NOTE(review): inputs[1] / inputs[2] are taken as the two branches,
        # which presumably assumes a single-output IfElse — confirm.
        left = branch_apply_nodes(cond.inputs[1])
        right = branch_apply_nodes(cond.inputs[2])
        middle = left & right
        left -= middle
        right -= middle

    var_str = {}
    var_id = {}
    all_strings = set()

    def var_name(var):
        if var in var_str:
            return var_str[var], var_id[var]

        if var.name is not None:
            if var_with_name_simple:
                varstr = var.name
            else:
                varstr = "name=" + var.name + " " + str(var.type)
        elif isinstance(var, gof.Constant):
            dstr = "val=" + str(np.asarray(var.data))
            if "\n" in dstr:
                dstr = dstr[: dstr.index("\n")]
            varstr = f"{dstr} {var.type}"
        elif var in input_update and input_update[var].name is not None:
            varstr = input_update[var].name
            if not var_with_name_simple:
                varstr += str(var.type)
        else:
            # a var id is needed as otherwise var with the same type will be
            # merged in the graph.
            varstr = str(var.type)
        if len(varstr) > max_label_size:
            varstr = varstr[: max_label_size - 3] + "..."
        var_str[var] = varstr
        var_id[var] = str(id(var))

        all_strings.add(varstr)

        return varstr, var_id[var]

    apply_name_cache = {}
    apply_name_id = {}

    def apply_name(node):
        if node in apply_name_cache:
            return apply_name_cache[node], apply_name_id[node]
        prof_str = ""
        if profile:
            time = profile.apply_time.get((fgraph, node), 0)
            # second, %fct time in profiler
            if profile.fct_callcount == 0 or profile.fct_call_time == 0:
                pf = 0
            else:
                pf = time * 100 / profile.fct_call_time
            prof_str = f"   ({time:.3f}s,{pf:.3f}%)"
        applystr = str(node.op).replace(":", "_")
        applystr += prof_str
        if (applystr in all_strings) or with_ids:
            idx = " id=" + str(topo.index(node))
            if len(applystr) + len(idx) > max_label_size:
                applystr = applystr[: max_label_size - 3 - len(idx)] + idx + "..."
            else:
                applystr = applystr + idx
        elif len(applystr) > max_label_size:
            applystr = applystr[: max_label_size - 3] + "..."
            idx = 1
            while applystr in all_strings:
                idx += 1
                suffix = " id=" + str(idx)
                applystr = applystr[: max_label_size - 3 - len(suffix)] + "..." + suffix

        all_strings.add(applystr)
        apply_name_cache[node] = applystr
        apply_name_id[node] = str(id(node))

        return applystr, apply_name_id[node]

    # Update the inputs that have an update function
    input_update = {}
    reverse_input_update = {}
    # Here outputs can be the original list, as we should not change
    # it, we must copy it.
    outputs = list(outputs)
    if isinstance(fct, Function):
        function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
        for i, fg_ii in reversed(list(function_inputs)):
            if i.update is not None:
                k = outputs.pop()
                # Use the fgaph.inputs as it isn't the same as maker.inputs
                input_update[k] = fg_ii
                reverse_input_update[fg_ii] = k

    apply_shape = "ellipse"
    var_shape = "box"
    for node_idx, node in enumerate(topo):
        astr, aid = apply_name(node)

        use_color = None
        for opName, color in colorCodes.items():
            if opName in node.op.__class__.__name__:
                use_color = color

        if use_color is None:
            nw_node = pd.Node(aid, label=astr, shape=apply_shape)
        elif high_contrast:
            nw_node = pd.Node(
                aid, label=astr, style="filled", fillcolor=use_color, shape=apply_shape
            )
        else:
            nw_node = pd.Node(aid, label=astr, color=use_color, shape=apply_shape)
        g.add_node(nw_node)
        if cond_highlight:
            if node in middle:
                c3.add_node(nw_node)
            elif node in left:
                c1.add_node(nw_node)
            elif node in right:
                c2.add_node(nw_node)

        for idx, var in enumerate(node.inputs):
            varstr, varid = var_name(var)
            label = ""
            if len(node.inputs) > 1:
                label = str(idx)
            param = {}
            if label:
                param["label"] = label
            if hasattr(node.op, "view_map") and idx in reduce(
                list.__add__, node.op.view_map.values(), []
            ):
                param["color"] = colorCodes["Output"]
            elif hasattr(node.op, "destroy_map") and idx in reduce(
                list.__add__, node.op.destroy_map.values(), []
            ):
                param["color"] = "red"
            if var.owner is None:
                color = "green"
                if isinstance(var, SharedVariable):
                    # Input are green, output blue
                    # Mixing blue and green give cyan! (input and output var)
                    color = "cyan"
                if high_contrast:
                    g.add_node(
                        pd.Node(
                            varid,
                            style="filled",
                            fillcolor=color,
                            label=varstr,
                            shape=var_shape,
                        )
                    )
                else:
                    g.add_node(
                        pd.Node(varid, color=color, label=varstr, shape=var_shape)
                    )
                g.add_edge(pd.Edge(varid, aid, **param))
            elif var.name or not compact or var in outputs:
                g.add_edge(pd.Edge(varid, aid, **param))
            else:
                # no name, so we don't make a var ellipse
                if label:
                    label += " "
                label += str(var.type)
                if len(label) > max_label_size:
                    label = label[: max_label_size - 3] + "..."
                param["label"] = label
                g.add_edge(pd.Edge(apply_name(var.owner)[1], aid, **param))

        for idx, var in enumerate(node.outputs):
            varstr, varid = var_name(var)
            out = var in outputs
            label = ""
            if len(node.outputs) > 1:
                label = str(idx)
            if len(label) > max_label_size:
                label = label[: max_label_size - 3] + "..."
            param = {}
            if label:
                param["label"] = label
            if out or var in input_update:
                g.add_edge(pd.Edge(aid, varid, **param))
                if high_contrast:
                    g.add_node(
                        pd.Node(
                            varid,
                            style="filled",
                            label=varstr,
                            fillcolor=colorCodes["Output"],
                            shape=var_shape,
                        )
                    )
                else:
                    g.add_node(
                        pd.Node(
                            varid,
                            color=colorCodes["Output"],
                            label=varstr,
                            shape=var_shape,
                        )
                    )
            elif len(fgraph.clients[var]) == 0:
                g.add_edge(pd.Edge(aid, varid, **param))
                # grey mean that output var isn't used
                if high_contrast:
                    g.add_node(
                        pd.Node(
                            varid,
                            style="filled",
                            label=varstr,
                            fillcolor="grey",
                            shape=var_shape,
                        )
                    )
                else:
                    g.add_node(
                        pd.Node(varid, label=varstr, color="grey", shape=var_shape)
                    )
            elif var.name or not compact:
                if compact:
                    # compact mode: annotate the edge of a named var with its type
                    if label:
                        label += " "
                    label += str(var.type)
                    if len(label) > max_label_size:
                        label = label[: max_label_size - 3] + "..."
                    param["label"] = label
                g.add_edge(pd.Edge(aid, varid, **param))
                g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
            # otherwise: don't add an edge here, as it was already added
            # from the consuming node's inputs.

    # The var that represent updates, must be linked to the input var.
    for sha, up in input_update.items():
        _, shaid = var_name(sha)
        _, upid = var_name(up)
        g.add_edge(pd.Edge(shaid, upid, label="UPDATE", color=colorCodes["Output"]))

    if cond_highlight:
        g.add_subgraph(c1)
        g.add_subgraph(c2)
        g.add_subgraph(c3)

    if not outfile.endswith("." + format):
        outfile += "." + format

    if scan_graphs:
        scan_ops = [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)]
        path, fn = os.path.split(outfile)
        basename = ".".join(fn.split(".")[:-1])
        # Safe way of doing things .. a file name may contain multiple .
        ext = fn[len(basename) :]

        for idx, scan_op in scan_ops:
            # is there a chance that name is not defined?
            if hasattr(scan_op.op, "name"):
                new_name = basename + "_" + scan_op.op.name + "_" + str(idx)
            else:
                new_name = basename + "_" + str(idx)
            new_name = os.path.join(path, new_name + ext)
            if hasattr(scan_op.op, "fn"):
                to_print = scan_op.op.fn
            else:
                to_print = scan_op.op.outputs
            # forward every display option so inner graphs match the outer one
            pydotprint(
                to_print,
                new_name,
                compact=compact,
                format=format,
                with_ids=with_ids,
                high_contrast=high_contrast,
                cond_highlight=cond_highlight,
                colorCodes=colorCodes,
                max_label_size=max_label_size,
                scan_graphs=scan_graphs,
                var_with_name_simple=var_with_name_simple,
                print_output_file=print_output_file,
            )

    if return_image:
        return g.create(prog="dot", format=format)
    else:
        try:
            g.write(outfile, prog="dot", format=format)
        except pd.InvocationException:
            # based on https://github.com/Theano/Theano/issues/2988
            version = getattr(pd, "__version__", "")
            version_parts = version.split(".") if version else []
            # only compare purely numeric versions; a dev/rc tag must not
            # raise ValueError here and mask the original error
            if (
                version_parts
                and all(part.isdigit() for part in version_parts)
                and [int(part) for part in version_parts] < [1, 0, 28]
            ):
                raise Exception(
                    "Old version of pydot detected, which can "
                    "cause issues with pydot printing. Try "
                    "upgrading pydot version to a newer one"
                )
            raise

        if print_output_file:
            print("The output file is available at", outfile)
Exemple #7
0
def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 show_params=False,
                 rankdir='TB',
                 **kwargs):
    """Convert a Keras model to dot format.
    # Arguments
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_layer_names: whether to display layer names.
        show_params: whether to append selected layer attributes
            (e.g. `filters`, `kernel_size`, `units`) to each node's label,
            one `name: value` line per attribute the layer has.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.
        **kwargs: may contain `attrs`, a list of extra layer attribute
            names to show when `show_params` is True.
    # Returns
        A `pydot.Dot` instance representing the Keras model.
    # Raises
        TypeError: if `attrs` is given and is not a list.
    """
    from keras.layers.wrappers import Wrapper
    from keras.models import Sequential

    _check_pydot()
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    if isinstance(model, Sequential):
        if not model.built:
            model.build()
        model = model.model
    layers = model.layers

    # Layer attributes listed in the label when `show_params` is True.
    # `supports_masking` is the Keras Layer attribute name.
    attrs = [
        'filters', 'padding', 'use_bias', 'kernel_size', 'strides',
        'pool_size', 'size', 'rate', 'dims', 'n', 'units', 'l1', 'l2',
        'supports_masking', 'epsilon', 'scale', 'momentum', 'dilation_rate'
    ]
    if 'attrs' in kwargs:
        extra_attrs = kwargs['attrs']
        if not isinstance(extra_attrs, list):
            raise TypeError(
                'extra attributes can only be list, given {}'.format(
                    type(extra_attrs)))
        attrs += extra_attrs

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__
        # The fill colour is chosen from the outer class name, before any
        # wrapper suffix is appended below.
        color = layer2color(class_name)
        if isinstance(layer, Wrapper):
            layer_name = '{}({})'.format(layer_name, layer.layer.name)
            child_class_name = layer.layer.__class__.__name__
            class_name = '{}({})'.format(class_name, child_class_name)

        # Create node's label.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
        else:
            label = class_name

        if show_params:
            attr_lines = ['%s: %s' % (attr, getattr(layer, attr))
                          for attr in attrs if hasattr(layer, attr)]
            if attr_lines:
                label = '%s\n%s' % (label, '\n'.join(attr_lines))

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = str(layer.output_shape)
            except AttributeError:
                outputlabels = 'multiple'
            if hasattr(layer, 'input_shape'):
                inputlabels = str(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [str(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = 'multiple'
            label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
                                                           outputlabels)
        node = pydot.Node(layer_id,
                          label=label,
                          fillcolor=color,
                          style='filled')
        dot.add_node(node)

    # Connect nodes with edges; only inbound nodes registered in the model's
    # container are drawn.
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer.inbound_nodes):
            node_key = layer.name + '_ib-' + str(i)
            if node_key in model.container_nodes:
                for inbound_layer in node.inbound_layers:
                    inbound_layer_id = str(id(inbound_layer))
                    dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
    return dot
Exemple #8
0
def add_edge(dot, src, dst, output_shape=None):
    """Add a ``src -> dst`` edge to ``dot`` unless such an edge already exists.

    Args:
        dot: pydot graph to modify.
        src: name of the source node.
        dst: name of the destination node.
        output_shape: optional edge label; a falsy value adds an unlabelled edge.
    """
    if dot.get_edge(src, dst):
        # Keep the graph free of duplicate edges.
        return
    edge_options = {'label': output_shape} if output_shape else {}
    dot.add_edge(pydot.Edge(src, dst, **edge_options))
Exemple #9
0
    def plot_apply(app, d):
        """Draw apply node ``app`` with its input/output variables, then recurse.

        This is a closure. From the enclosing scope it reads ``my_list`` (a
        node -> label memo, mutated here), ``g`` (the pydot graph),
        ``max_label_size``, ``colorCodes``, ``high_contrast``, ``vars``,
        ``orphanes``, ``apply_name`` and ``var_name``.

        Args:
            app: the apply node to draw. Nothing happens if it is already
                in ``my_list``.
            d: remaining depth. The walk stops when it reaches 0 and
                decreases by one for each input owner that is followed.
        """
        if d == 0:
            return
        if app in my_list:
            return
        astr = apply_name(app) + '_' + str(len(my_list))
        if len(astr) > max_label_size:
            astr = astr[:max_label_size - 3] + '...'
        my_list[app] = astr

        # When several colorCodes keys occur in the op class name, the last
        # one in iteration order wins.
        use_color = None
        for opName, color in iteritems(colorCodes):
            if opName in app.op.__class__.__name__:
                use_color = color

        if use_color is None:
            g.add_node(pd.Node(astr, shape='box'))
        elif high_contrast:
            g.add_node(
                pd.Node(astr, style='filled', fillcolor=use_color,
                        shape='box'))
        else:
            g.add_node(pd.Node(astr, color=use_color, shape='box'))

        # Input variables: plain when produced by another apply node,
        # green otherwise.
        for i, nd in enumerate(app.inputs):
            if nd not in my_list:
                varastr = var_name(nd) + '_' + str(len(my_list))
                if len(varastr) > max_label_size:
                    varastr = varastr[:max_label_size - 3] + '...'
                my_list[nd] = varastr
                if nd.owner is not None:
                    g.add_node(pd.Node(varastr))
                elif high_contrast:
                    g.add_node(
                        pd.Node(varastr, style='filled', fillcolor='green'))
                else:
                    g.add_node(pd.Node(varastr, color='green'))
            else:
                varastr = my_list[nd]
            # Number the edges only when the apply node has several inputs.
            label = None
            if len(app.inputs) > 1:
                label = str(i)
            g.add_edge(pd.Edge(varastr, astr, label=label))

        # Output variables: those in `vars` use colorCodes['Output'],
        # those in `orphanes` are gray.
        for i, nd in enumerate(app.outputs):
            if nd not in my_list:
                varastr = var_name(nd) + '_' + str(len(my_list))
                if len(varastr) > max_label_size:
                    varastr = varastr[:max_label_size - 3] + '...'
                my_list[nd] = varastr
                color = None
                if nd in vars:
                    color = colorCodes['Output']
                elif nd in orphanes:
                    color = 'gray'
                if color is None:
                    g.add_node(pd.Node(varastr))
                elif high_contrast:
                    g.add_node(
                        pd.Node(varastr, style='filled', fillcolor=color))
                else:
                    g.add_node(pd.Node(varastr, color=color))
            else:
                varastr = my_list[nd]
            label = None
            if len(app.outputs) > 1:
                label = str(i)
            g.add_edge(pd.Edge(astr, varastr, label=label))
        for nd in app.inputs:
            if nd.owner:
                plot_apply(nd.owner, d - 1)
Exemple #10
0
def plot_graph(graph, graph_img_path='graph.png', show_coreml_mapped_shapes=False):
    """
    Plot graph using pydot

    It works in two steps:
    1. Add nodes to pydot
    2. connect nodes added in pydot

    :param graph: graph exposing ``inputs`` (tuples whose [0] is the tensor
        name and whose [2] is its shape), ``nodes`` (each with ``name``,
        ``op_type``, ``inputs``, ``outputs`` and ``children``), and either
        ``shape_dict`` or ``onnx_coreml_shape_mapping``, depending on
        ``show_coreml_mapped_shapes``.
    :param graph_img_path: output path. Its extension (without the dot)
        selects the output format; 'pdf' is used when there is none.
    :param show_coreml_mapped_shapes: label tensors with shapes from
        ``graph.onnx_coreml_shape_mapping`` instead of ``graph.shape_dict``
        (or, for graph inputs, the shape stored in the input tuple).
    :return: None. Writes a file using dot. If no pydot implementation can
        be imported, returns None without writing anything.
    """

    # Best effort: try each pydot implementation in turn. `except Exception`
    # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        # pydot-ng is a fork of pydot that is better maintained.
        import pydot_ng as pydot  # type: ignore
    except Exception:
        # pydotplus is an improved version of pydot
        try:
            import pydotplus as pydot  # type: ignore
        except Exception:
            # Fall back on pydot if necessary.
            try:
                import pydot  # type: ignore
            except Exception:
                return None

    def _shape_text(tensor_name):
        """Return the display shape of ``tensor_name``, or 'NA' if unknown."""
        if show_coreml_mapped_shapes:
            mapping = graph.onnx_coreml_shape_mapping
            if tensor_name in mapping:
                return str(tuple(_shape_notation(mapping[tensor_name])))
            return 'NA'
        if tensor_name in graph.shape_dict:
            return str(tuple(graph.shape_dict[tensor_name]))
        return 'NA'

    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    # Add nodes corresponding to graph inputs
    graph_inputs = set()
    for input_ in graph.inputs:
        input_name = input_[0]
        if show_coreml_mapped_shapes:
            shape = _shape_text(input_name)
        else:
            shape = str(tuple(input_[2]))
        label = '%s\n|{|%s}|{{%s}|{%s}}' % ('Input',
                                            input_name,
                                            '',
                                            shape)
        dot.add_node(pydot.Node(input_name, label=label))
        graph_inputs.add(input_name)

    # Traverse graph and add nodes to pydot. Every shape entry is followed by
    # ', ', including the last one.
    for node in graph.nodes:
        inputlabels = ''.join(_shape_text(name) + ', ' for name in node.inputs)
        outputlabels = ''.join(_shape_text(name) + ', ' for name in node.outputs)
        output_names = ', '.join(node.outputs)
        input_names = ', '.join(node.inputs)
        label = '%s\n|{{%s}|{%s}}|{{%s}|{%s}}' % (node.op_type,
                                                  input_names,
                                                  output_names,
                                                  inputlabels,
                                                  outputlabels)
        dot.add_node(pydot.Node(node.name, label=label))

    # Add edges: node -> child, and graph input -> consuming node.
    for node in graph.nodes:
        for child in node.children:
            dot.add_edge(pydot.Edge(node.name, child.name))
        for input_ in node.inputs:
            if input_ in graph_inputs:
                dot.add_edge(pydot.Edge(input_, node.name))

    # write out the image file
    _, extension = os.path.splitext(graph_img_path)
    if not extension:
        extension = 'pdf'
    else:
        extension = extension[1:]
    dot.write(graph_img_path, format=extension)
def to_pydot(N, strict=True):
    """Return a pydot graph from a NetworkX graph N.

    Parameters
    ----------
    N : NetworkX graph
      A graph created with NetworkX
    strict : bool, optional (default True)
      Request a strict pydot graph. The result is strict only when this is
      True *and* ``N`` has no self-loops and is not a multigraph, so that no
      edges are lost. Pass False to always build a non-strict graph.

    Returns
    -------
    P : pydot.Dot
      A 'digraph' if ``N`` is directed, otherwise a 'graph'.
      ``N.graph['name']`` (if present) becomes the quoted graph name, and
      ``N.graph['graph']`` supplies graph attributes. ``N.graph['node']`` and
      ``N.graph['edge']`` become node and edge defaults. Node ids, keys and
      node/edge attribute values are converted with ``make_str``.

    Examples
    --------
    >>> import networkx as nx
    >>> K5 = nx.complete_graph(5)
    >>> P = nx.to_pydot(K5)
    """
    # set Graphviz graph type
    graph_type = 'digraph' if N.is_directed() else 'graph'
    # Honour the caller's `strict`, but never produce a strict graph for
    # input containing self-loops or parallel edges.
    strict = bool(strict) and N.number_of_selfloops() == 0 \
        and not N.is_multigraph()

    name = N.graph.get('name')
    graph_defaults = N.graph.get('graph', {})
    if name is None:
        P = pydot.Dot(graph_type=graph_type, strict=strict, **graph_defaults)
    else:
        P = pydot.Dot('"%s"' % name,
                      graph_type=graph_type,
                      strict=strict,
                      **graph_defaults)
    if 'node' in N.graph:
        P.set_node_defaults(**N.graph['node'])
    if 'edge' in N.graph:
        P.set_edge_defaults(**N.graph['edge'])

    for n, nodedata in N.nodes_iter(data=True):
        str_nodedata = dict((k, make_str(val)) for k, val in nodedata.items())
        P.add_node(pydot.Node(make_str(n), **str_nodedata))

    if N.is_multigraph():
        for u, v, key, edgedata in N.edges_iter(data=True, keys=True):
            str_edgedata = dict(
                (k, make_str(val)) for k, val in edgedata.items())
            edge = pydot.Edge(make_str(u),
                              make_str(v),
                              key=make_str(key),
                              **str_edgedata)
            P.add_edge(edge)
    else:
        for u, v, edgedata in N.edges_iter(data=True):
            str_edgedata = dict(
                (k, make_str(val)) for k, val in edgedata.items())
            P.add_edge(pydot.Edge(make_str(u), make_str(v), **str_edgedata))
    return P
Exemple #12
0
def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 rankdir='TB'):
    """Convert a Keras model to dot format.

    Arguments:
        model: A Keras model instance. An unbuilt Sequential model is built
            first.
        show_shapes: whether to display shape information.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.

    Returns:
        A `pydot.Dot` instance representing the Keras model. Returns None
        (after printing a message) when pydot is unavailable and IPython's
        magics module is loaded.

    Raises:
        ImportError: if graphviz or pydot are not available outside IPython.
    """
    from tensorflow.python.keras.layers.wrappers import Wrapper
    from tensorflow.python.keras.models import Sequential
    from tensorflow.python.util import nest

    if not _check_pydot():
        failure = ('Failed to import pydot. You must install pydot'
                   ' and graphviz for `pydotprint` to work.')
        if 'IPython.core.magics.namespace' not in sys.modules:
            raise ImportError(failure)
        # Under IPython we print instead of raising, so notebook tests can
        # run where graphviz is not available.
        print(failure)
        return

    def _node_label(layer):
        """Build the record-shaped label text for one layer."""
        shown_name = layer.name
        shown_class = layer.__class__.__name__
        if isinstance(layer, Wrapper):
            # Describe wrappers as Outer(inner).
            shown_name = '{}({})'.format(shown_name, layer.layer.name)
            shown_class = '{}({})'.format(shown_class,
                                          layer.layer.__class__.__name__)
        text = ('{}: {}'.format(shown_name, shown_class)
                if show_layer_names else shown_class)
        if not show_shapes:
            return text
        try:
            out_text = str(layer.output_shape)
        except AttributeError:
            out_text = 'multiple'
        if hasattr(layer, 'input_shape'):
            in_text = str(layer.input_shape)
        elif hasattr(layer, 'input_shapes'):
            in_text = ', '.join(str(shape) for shape in layer.input_shapes)
        else:
            in_text = 'multiple'
        return '%s\n|{input:|output:}|{{%s}|{%s}}' % (text, in_text, out_text)

    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    if isinstance(model, Sequential) and not model.built:
        model.build()
    layers = model._layers

    # One node per layer, keyed by the layer object's id().
    for layer in layers:
        dot.add_node(pydot.Node(str(id(layer)), label=_node_label(layer)))

    # Edges come only from inbound nodes registered in the model's network.
    for layer in layers:
        target_id = str(id(layer))
        for position, inbound in enumerate(layer._inbound_nodes):
            key = layer.name + '_ib-' + str(position)
            if key not in model._network_nodes:  # pylint: disable=protected-access
                continue
            for source in nest.flatten(inbound.inbound_layers):
                dot.add_edge(pydot.Edge(str(id(source)), target_id))
    return dot
Exemple #13
0
def get_pydot_graph(layers, output_shape=True, rankdir="LR"):
    """
    Creates a PyDot graph of the network defined by the given layers.
    :parameters:
        - layers : list
            List of the layers, as obtained from lasagne.layers.get_all_layers
        - output_shape: (default `True`)
            If `True`, the output shape of each layer will be displayed.
        - rankdir: (default `"LR"`)
            Graphviz `rankdir` of the graph. For layers with a
            nonlinearity, any value other than "LR" places the
            nonlinearity cell in a new table row below the layer cell.
    :returns:
        - pydot_graph : PyDot object containing the graph

    Attributes such as `num_filters`, `num_units`, `filter_shape` and
    `stride` are always added to a node's label when the layer has them.
    Every layer referenced through `input_layer` / `input_layers` must also
    be in `layers`; otherwise a KeyError is raised while adding edges.
    """
    pydot_graph = pydot.Dot('Network',
                            graph_type='digraph',
                            rankdir=rankdir,
                            bgcolor='transparent')
    pydot_nodes = {}
    pydot_edges = []
    for layer in layers:
        layer_type = layer.__class__.__name__
        key = repr(layer)
        label = layer_type
        color = get_hex_color(layer_type)
        shape = 'record'
        for attr in ('num_filters', 'num_units', 'ds', 'filter_shape',
                     'stride', 'strides', 'p'):
            if hasattr(layer, attr):
                label += '\n' + '{0}: {1}'.format(attr, getattr(layer, attr))

        if output_shape:
            label += '\n' + 'Output shape: {0}'.format(layer.output_shape)

        if hasattr(layer, 'nonlinearity'):
            # Nonlinearities may be plain functions or callable objects.
            try:
                nonlinearity = layer.nonlinearity.__name__
            except AttributeError:
                nonlinearity = layer.nonlinearity.__class__.__name__
            # Switch to a Graphviz HTML-like label: a coloured cell with the
            # layer text plus a yellow cell naming the nonlinearity.
            label = ("<<TABLE CELLSPACING='0' CELLPADDING='0' BORDER='0' "
                     "CELLBORDER='1'><TR><TD BGCOLOR='%s'>" % color
                     + label.replace("\n", "<BR/>") + "</TD>")
            if rankdir != "LR":
                label += "</TR><TR>"
            label += ('<TD BGCOLOR="#FFFF00">' + '{0}'.format(nonlinearity)
                      + "</TD></TR></TABLE>>")
            # The table cells carry the colours; the node itself is unstyled.
            color = "none"
            shape = 'none'
        pydot_nodes[key] = pydot.Node(
            key,
            label=label,
            shape=shape,
            fillcolor=color,
            style='filled',
        )

        if hasattr(layer, 'input_layers'):
            for input_layer in layer.input_layers:
                pydot_edges.append([repr(input_layer), key])

        if hasattr(layer, 'input_layer'):
            pydot_edges.append([repr(layer.input_layer), key])

    for node in pydot_nodes.values():
        pydot_graph.add_node(node)
    for src_key, dst_key in pydot_edges:
        pydot_graph.add_edge(
            pydot.Edge(pydot_nodes[src_key], pydot_nodes[dst_key]))
    return pydot_graph
Exemple #14
0
    def writeDotGraph(self, graph, startNode=0):
        """ Add this tree's nodes and parent->child edges to a pydot Graph.

        :param graph: the pydot Graph instance to populate
        :param startNode: first integer used for node names; lets several
                          individuals be drawn into the same graph
        :returns: the next unused node number (``startNode`` plus the number
                  of nodes added), or None when pydot is unavailable
        """
        from . import Consts

        if not HAVE_PYDOT:
            print("You must install Pydot to use this feature !")
            return

        next_id = startNode
        pending = []
        dot_nodes = {}
        import __main__ as main_module

        for tree_node in self.nodes_list:
            dot_node = pydot.Node(str(next_id), style="filled")
            next_id += 1

            node_type = tree_node.getType()
            if node_type == Consts.nodeType["TERMINAL"]:
                dot_node.set_color("lightblue2")
            else:
                dot_node.set_color("goldenrod2")

            if node_type != Consts.nodeType["NONTERMINAL"]:
                dot_node.set_label(tree_node.getData())
            else:
                # Non-terminal data names a function in __main__; its optional
                # shape/representation/color attributes customise the node.
                func = getattr(main_module, tree_node.getData())

                if hasattr(func, "shape"):
                    dot_node.set_shape(func.shape)

                if hasattr(func, "representation"):
                    dot_node.set_label(func.representation)
                else:
                    dot_node.set_label(tree_node.getData())
                if hasattr(func, "color"):
                    dot_node.set_color(func.color)

            dot_nodes[tree_node] = dot_node
            graph.add_node(dot_node)

        # Depth-first walk from the root, adding one edge per parent link.
        pending.append(self.getRoot())
        while pending:
            current = pending.pop()

            parent = current.getParent()
            if parent is not None:
                graph.add_edge(pydot.Edge(dot_nodes[parent], dot_nodes[current]))

            # Push children reversed so they pop in their original order.
            pending.extend(reversed(current.getChilds()))

        return next_id
Exemple #15
0
def model_to_dot(model, rankdir='TB'):
    """Convert a Keras model to a compact dot graph.

    Each layer becomes a box labelled with its class name (`InputLayer` is
    shown as `Input` and outlined in red), followed by:
      * the unit count, for layers with a `units` attribute;
      * `<kernel>,<filters>`, for convolution layers;
      * the pool size, for 1D/2D pooling layers.
    Edges are labelled with the producing layer's output shape without its
    first (batch) dimension. An extra `Output` node is attached to the last
    layer.

    # Arguments
        model: A Keras model instance. A Sequential model is built first if
            needed and its inner functional model is used.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.
    # Returns
        A `pydot.Dot` instance representing the Keras model.
    # Raises
        ValueError: if an inbound layer reports more than one output shape.
    """
    from keras.layers.wrappers import Wrapper
    from keras.models import Sequential

    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    if isinstance(model, Sequential):
        if not model.built:
            model.build()
        model = model.model
    layers = model.layers

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__
        if isinstance(layer, Wrapper):
            layer_name = '{}({})'.format(layer_name, layer.layer.name)
            child_class_name = layer.layer.__class__.__name__
            class_name = '{}({})'.format(class_name, child_class_name)

        if class_name == "InputLayer":
            class_name = "Input"

        # Create node's label.
        label = class_name

        # Layers with `units` (e.g. Dense) show the unit count; otherwise
        # try convolution and pooling details.
        try:
            label += " " + str(layer.units)
        except AttributeError:
            if isinstance(layer, keras.layers.convolutional._Conv):
                kernel_size = "x".join([str(k) for k in layer.kernel_size])
                label += " %s,%s" % (kernel_size, str(layer.filters))

            if isinstance(layer, keras.layers.pooling._Pooling1D):
                label += " " + str(layer.pool_size)

            if isinstance(layer, keras.layers.pooling._Pooling2D):
                pool_size = [str(k) for k in layer.pool_size]
                label += " " + "x".join(pool_size)

        node = pydot.Node(layer_id, label=label)
        node.set("shape", 'box')
        if class_name == "Input":
            node.set("color", 'red')

        dot.add_node(node)

    # add output node
    output_node = pydot.Node("output_node", label="Output")
    output_node.set("shape", 'box')
    dot.add_node(output_node)

    # Connect nodes with edges labelled by the inbound layer's output shape.
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer.inbound_nodes):
            node_key = layer.name + '_ib-' + str(i)
            if node_key not in model._network_nodes:
                continue
            inbound_layers = node.inbound_layers
            if not isinstance(inbound_layers, list):
                inbound_layers = [inbound_layers]
            for inbound_layer in inbound_layers:
                output_shape = inbound_layer.output_shape
                if isinstance(output_shape, list):
                    if len(output_shape) > 1:
                        raise ValueError("More than one output_shape found")
                    output_shape = output_shape[0]
                # Drop the leading (batch) dimension from the label.
                edge = pydot.Edge(str(id(inbound_layer)), layer_id,
                                  label=str(output_shape[1:]))
                dot.add_edge(edge)

    # connect output
    out_edge = pydot.Edge(str(id(layers[-1])),
                          "output_node",
                          label=str(model.output_shape[1:]))
    dot.add_edge(out_edge)

    return dot