Beispiel #1
0
    def writePopulationDotRaw(ga_engine, filename, start=0, end=0):
        """ Writes to a raw dot file using pydot, the population of trees

        Example:
           >>> GTreeGP.writePopulationDotRaw(ga_engine, "pop.dot", 0, 10)

        This example will draw the first ten individuals of the population into
        the file called "pop.dot".

        Every drawn individual gets its own ``cluster_<i>`` subgraph whose label
        shows the individual's index and its raw and fitness scores.

        :param ga_engine: the GA Engine
        :param filename: the filename, ie. population.dot
        :param start: the start index of individuals (inclusive)
        :param end: the end index of individuals (exclusive); ``0`` means "up to
                    the end of the population". Values larger than the
                    population size are clamped to it.
        """
        if not HAVE_PYDOT:
            Util.raiseException("You must install Pydot to use this feature !")

        pop = ga_engine.getPopulation()
        graph = pydot.Dot(graph_type="digraph")

        if not isinstance(pop[0], GTreeGP):
            Util.raiseException("The population must have individuals of the GTreeGP chromosome !")

        pop_size = len(pop)
        # Clamp so that asking for more individuals than exist (e.g. end=10 on a
        # population of 5) draws what is available instead of raising IndexError.
        end_index = pop_size if end == 0 else min(end, pop_size)

        # Running counter threaded through writeDotGraph; presumably the next
        # free node id, so node names stay unique across all clusters.
        n = 0
        for i in range(start, end_index):
            ind = pop[i]
            subg = pydot.Cluster(
                "cluster_%d" % i,
                label="\"Ind. #%d - Score Raw/Fit.: %.4f/%.4f\"" % (i, ind.getRawScore(), ind.getFitnessScore())
            )
            n = ind.writeDotGraph(subg, n)
            graph.add_subgraph(subg)

        graph.write(filename, prog='dot', format="raw")
def draw_legend(graph):
    """Add a colour legend cluster to ``graph`` (modified in place).

    The legend is a cluster named ``legend`` holding three plain-text captions
    followed by three unlabeled, colour-filled swatch nodes. Invisible edges
    chain the captions together and the swatches together so that Graphviz
    keeps each group in the listed order.

    :param graph: the pydot graph that receives the legend subgraph and the
                  invisible layout edges
    """
    legend_box = pydot.Cluster(
        graph_name="legend",
        label="Legend",
        fontsize="20",
        color="red",
        fontcolor="black",
        style="filled",
        fillcolor="white",
    )

    caption_texts = ("Processed node", "Depth limit reached", "Goal Node")
    swatch_specs = (("1", "coral"), ("2", "springgreen"), ("3", "aquamarine2"))

    captions = [pydot.Node(text, shape="plaintext") for text in caption_texts]
    swatches = [
        pydot.Node(node_name, style="filled", fillcolor=colour, label="")
        for node_name, colour in swatch_specs
    ]

    for member in captions + swatches:
        legend_box.add_node(member)

    graph.add_subgraph(legend_box)

    # Invisible edges only affect layout: first the caption chain, then the
    # swatch chain.
    for chain in (captions, swatches):
        for upper, lower in zip(chain, chain[1:]):
            graph.add_edge(pydot.Edge(upper, lower, style="invis"))
Beispiel #3
0
def to_pydot_graph(dag, sub_graph=False, input_edge=None):
    """Convert a DAG description into a pydot graph.

    ``dag`` maps node ids to sequences. For every node, ``dag[n][1]`` is either
    the string ``'input'`` or a pair ``(op_name, params)``. A non-input node is
    labelled ``op_name(v1,v2,...)`` from the values of ``params``, or just
    ``op_name`` when ``params`` is empty. A node whose op name is
    ``'booster'`` is drawn as a ``booster`` node, and each DAG in
    ``params['sub_dags']`` is converted recursively and added as a subgraph.
    Nodes and edges are enumerated through ``dag_to_nx(dag)`` (defined
    elsewhere).

    :param dag: the DAG mapping. It must contain an ``'input'`` entry; that
        entry's ``[2]`` element, without its last two characters, is used as
        a suffix (``gid``) for the input node's name.
    :param sub_graph: when True, build a ``pydot.Cluster`` labelled
        ``'booster'`` instead of a top-level ``pydot.Dot``, and do not draw
        the input node (its edges are attached to ``input_edge`` instead).
    :param input_edge: name of a node in the enclosing graph that stands in
        for this DAG's input; ``None`` for a top-level graph.
    :returns: the ``pydot.Dot`` (or ``pydot.Cluster``) that was built.
    """
    import pydot_ng as pydot

    gid = dag['input'][2][:-2]
    if sub_graph:
        graph = pydot.Cluster(label='booster')
    else:
        graph = pydot.Dot(graph_type='digraph')
    dag_nx = dag_to_nx(dag)

    for n in dag_nx.nodes():
        op = dag[n][1]
        if op == 'input':
            label = 'input'
        else:
            # Only parameter values are shown (not names) to keep labels short.
            params = op[1]
            label = op[0] + ('(' + ','.join('{}'.format(v)
                                            for v in params.values()) +
                             ')' if params else '')

        node_name = n
        if label == 'input':
            if sub_graph:
                # The enclosing graph already provides the input (input_edge).
                continue
            node_name = label + gid
            if input_edge is not None:
                outs = dag[n][2]
                if not isinstance(outs, list):
                    outs = [outs]
                for o in outs:
                    graph.add_edge(pydot.Edge(input_edge, o))

        if dag[n][1][0] == 'booster':
            graph.add_node(pydot.Node(n, label='booster'))
            for sd in dag[n][1][1]['sub_dags']:
                subgraph = to_pydot_graph(sd,
                                          sub_graph=True,
                                          input_edge=dag[n][0])
                graph.add_subgraph(subgraph)
        else:
            graph.add_node(pydot.Node(node_name, label=label))

    for f, t in dag_nx.edges():
        if f == 'input':
            # Redirect edges leaving the input either to this graph's own
            # input node or to the stand-in node of the enclosing graph.
            f = 'input' + gid if input_edge is None else input_edge
        graph.add_edge(pydot.Edge(f, t))

    return graph
Beispiel #4
0
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False,
                 layer_range=None,
                 show_layer_activations=False):
    """Convert a Keras model to dot format.

    Args:
      model: A Keras model instance.
      show_shapes: whether to display shape information.
      show_dtype: whether to display layer dtypes.
      show_layer_names: whether to display layer names.
      rankdir: `rankdir` argument passed to PyDot,
          a string specifying the format of the plot:
          'TB' creates a vertical plot;
          'LR' creates a horizontal plot.
      expand_nested: whether to expand nested models into clusters.
      dpi: Dots per inch.
      subgraph: whether to return a `pydot.Cluster` instance.
      layer_range: input of `list` containing two `str` items, which is the
          starting layer name and ending layer name (both inclusive) indicating
          the range of layers for which the `pydot.Dot` will be generated. It
          also accepts regex patterns instead of exact name. In such case, start
          predicate will be the first element it matches to `layer_range[0]`
          and the end predicate will be the last element it matches to
          `layer_range[1]`. By default `None` which considers all layers of
          model. Note that you must pass range such that the resultant subgraph
          must be complete.
      show_layer_activations: Display layer activations (only for layers that
          have an `activation` property). Also applied to nested models when
          `expand_nested=True`.

    Returns:
      A `pydot.Dot` instance representing the Keras model or
      a `pydot.Cluster` instance representing nested model if
      `subgraph=True`.

    Raises:
      ValueError: if `model_to_dot` is called before the model is built, or if
        `layer_range` is malformed.
      ImportError: if pydot is not available.
    """

    if not model.built:
        raise ValueError(
            'This model has not yet been built. '
            'Build the model first by calling `build()` or by calling '
            'the model on a batch of data.')

    from keras.layers import Wrapper
    from keras.engine import sequential
    from keras.engine import functional

    if not check_pydot():
        raise ImportError('You must install pydot (`pip install pydot`) for '
                          'model_to_dot to work.')

    if subgraph:
        dot = pydot.Cluster(style='dashed', graph_name=model.name)
        dot.set('label', model.name)
        dot.set('labeljust', 'l')
    else:
        dot = pydot.Dot()
        dot.set('rankdir', rankdir)
        dot.set('concentrate', True)
        dot.set('dpi', dpi)
        dot.set_node_defaults(shape='record')

    if layer_range is not None:
        if len(layer_range) != 2:
            raise ValueError(
                'layer_range must be of shape (2,). Received: '
                f'layer_range = {layer_range} of length {len(layer_range)}')
        if (not isinstance(layer_range[0], str)
                or not isinstance(layer_range[1], str)):
            raise ValueError('layer_range should contain string type only. '
                             f'Received: {layer_range}')
        # From here on layer_range holds [first_index, last_index] (inclusive).
        layer_range = get_layer_index_bound_by_layer_name(model, layer_range)
        if layer_range[0] < 0 or layer_range[1] > len(model.layers):
            raise ValueError(
                'Both values in layer_range should be in range (0, '
                f'{len(model.layers)}). Received: {layer_range}')

    def format_dtype(dtype):
        return '?' if dtype is None else str(dtype)

    def format_shape(shape):
        return str(shape).replace(str(None), 'None')

    # First/last pydot nodes of each expanded nested model, keyed by model
    # name; used below as edge endpoints into and out of the nested cluster.
    # sub_n: nested model used directly; sub_w: nested model inside a Wrapper.
    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    layers = model.layers
    if not model._is_graph_network:
        # Subclassed models have no inspectable layer graph: draw one node.
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        if not model.built:
            model.build()
        layers = super(sequential.Sequential, model).layers

    # Create graph nodes.
    for i, layer in enumerate(layers):
        if (layer_range) and (i < layer_range[0] or i > layer_range[1]):
            continue

        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__

        if isinstance(layer, Wrapper):
            if expand_nested and isinstance(layer.layer,
                                            functional.Functional):
                submodel_wrapper = model_to_dot(
                    layer.layer,
                    show_shapes,
                    show_dtype,
                    show_layer_names,
                    rankdir,
                    expand_nested,
                    subgraph=True,
                    show_layer_activations=show_layer_activations)
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = '{}({})'.format(layer_name, layer.layer.name)
                child_class_name = layer.layer.__class__.__name__
                class_name = '{}({})'.format(class_name, child_class_name)

        if expand_nested and isinstance(layer, functional.Functional):
            submodel_not_wrapper = model_to_dot(
                layer,
                show_shapes,
                show_dtype,
                show_layer_names,
                rankdir,
                expand_nested,
                subgraph=True,
                show_layer_activations=show_layer_activations)
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        label = class_name

        # Rebuild the label as a table including the layer's activation.
        if (show_layer_activations and hasattr(layer, 'activation')
                and layer.activation is not None):
            if hasattr(layer.activation, 'name'):
                activation_name = layer.activation.name
            elif hasattr(layer.activation, '__name__'):
                activation_name = layer.activation.__name__
            else:
                activation_name = str(layer.activation)
            label = '{%s|%s}' % (label, activation_name)

        # Rebuild the label as a table including the layer's name.
        if show_layer_names:
            label = '%s|%s' % (layer_name, label)

        # Rebuild the label as a table including the layer's dtype.
        if show_dtype:
            label = '%s|%s' % (label, format_dtype(layer.dtype))

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = format_shape(layer.output_shape)
            except AttributeError:
                outputlabels = '?'
            if hasattr(layer, 'input_shape'):
                inputlabels = format_shape(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [format_shape(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = '?'
            label = '{%s}|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
                                                           outputlabels)

        # An expanded nested model is represented by its cluster, not a node.
        if not expand_nested or not isinstance(layer, functional.Functional):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges.
    for i, layer in enumerate(layers):
        # Edges into the first in-range layer are skipped as well (`<=`): its
        # inbound layers generally precede the range and have no node in `dot`.
        if (layer_range) and (i <= layer_range[0] or i > layer_range[1]):
            continue
        layer_id = str(id(layer))
        for node_index, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(node_index)
            # Only follow inbound nodes that belong to this model's graph.
            if node_key in model._network_nodes:
                for inbound_layer in tf.nest.flatten(node.inbound_layers):
                    inbound_layer_id = str(id(inbound_layer))
                    if not expand_nested:
                        assert dot.get_node(inbound_layer_id)
                        assert dot.get_node(layer_id)
                        add_edge(dot, inbound_layer_id, layer_id)
                    else:
                        # if inbound_layer is not Model or wrapped Model
                        if (not isinstance(inbound_layer,
                                           functional.Functional)
                                and not is_wrapped_model(inbound_layer)):
                            # if current layer is not Model or wrapped Model
                            if (not isinstance(layer, functional.Functional)
                                    and not is_wrapped_model(layer)):
                                assert dot.get_node(inbound_layer_id)
                                assert dot.get_node(layer_id)
                                add_edge(dot, inbound_layer_id, layer_id)
                            # if current layer is Model
                            elif isinstance(layer, functional.Functional):
                                add_edge(
                                    dot, inbound_layer_id,
                                    sub_n_first_node[layer.name].get_name())
                            # if current layer is wrapped Model
                            elif is_wrapped_model(layer):
                                add_edge(dot, inbound_layer_id, layer_id)
                                name = sub_w_first_node[
                                    layer.layer.name].get_name()
                                add_edge(dot, layer_id, name)
                        # if inbound_layer is Model
                        elif isinstance(inbound_layer, functional.Functional):
                            name = sub_n_last_node[
                                inbound_layer.name].get_name()
                            if isinstance(layer, functional.Functional):
                                output_name = sub_n_first_node[
                                    layer.name].get_name()
                                add_edge(dot, name, output_name)
                            else:
                                add_edge(dot, name, layer_id)
                        # if inbound_layer is wrapped Model
                        elif is_wrapped_model(inbound_layer):
                            inbound_layer_name = inbound_layer.layer.name
                            add_edge(
                                dot,
                                sub_w_last_node[inbound_layer_name].get_name(),
                                layer_id)
    return dot
Beispiel #5
0
def model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    show_classes=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
):
    """Convert a Keras model to dot format.

    Arguments:
      model: A Keras model instance.
      show_shapes: whether to display shape information.
      show_dtype: whether to display layer dtypes.
      show_layer_names: whether to display layer names.
      show_classes: whether to display layer class names (not shown for
          layers whose name contains "tf_op_layer_").
      rankdir: `rankdir` argument passed to PyDot,
          a string specifying the format of the plot:
          'TB' creates a vertical plot;
          'LR' creates a horizontal plot.
      expand_nested: whether to expand nested models into clusters.
      dpi: Dots per inch.
      subgraph: whether to return a `pydot.Cluster` instance.

    Returns:
      A `pydot.Dot` instance representing the Keras model or
      a `pydot.Cluster` instance representing nested model if
      `subgraph=True`. Returns `None` (after printing a message) when pydot
      is unavailable inside an IPython session.

    Raises:
      ImportError: if graphviz or pydot are not available (outside IPython).
    """
    from tensorflow.python.keras.layers import wrappers
    from tensorflow.python.keras.engine import sequential
    from tensorflow.python.keras.engine import functional

    if not check_pydot():
        # A single string: stray commas previously turned this into a tuple,
        # so users saw a tuple repr instead of the sentence.
        message = (
            "Failed to import pydot. You must `pip install pydot` "
            "and install graphviz (https://graphviz.gitlab.io/download/), "
            "for `pydotprint` to work."
        )
        if "IPython.core.magics.namespace" in sys.modules:
            # We don't raise an exception here in order to avoid crashing notebook
            # tests where graphviz is not available.
            print(message)
            return
        else:
            raise ImportError(message)

    if subgraph:
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", rankdir)
        dot.set("concentrate", True)
        dot.set("dpi", dpi)
        dot.set_node_defaults(shape="record")

    def format_dtype(dtype):
        return "?" if dtype is None else str(dtype)

    def format_shape(shape):
        return str(shape).replace(str(None), "None")

    # First/last pydot nodes of each expanded nested model, keyed by model
    # name; used below as edge endpoints into and out of the nested cluster.
    # sub_n: nested model used directly; sub_w: nested model inside a Wrapper.
    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    layers = model.layers
    if not model._is_graph_network:
        # Subclassed models have no inspectable layer graph: draw one node.
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        if not model.built:
            model.build()
        layers = super(sequential.Sequential, model).layers

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name

        is_op = False
        if "tf_op_layer_" in layer_name:
            layer_name = layer_name.replace("tf_op_layer_", "")
            is_op = True
        class_name = layer.__class__.__name__

        if isinstance(layer, wrappers.Wrapper):
            if expand_nested and isinstance(layer.layer,
                                            functional.Functional):
                submodel_wrapper = model_to_dot(
                    layer.layer,
                    show_shapes,
                    show_dtype,
                    show_layer_names,
                    show_classes,
                    rankdir,
                    expand_nested,
                    subgraph=True,
                )
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = "{}({})".format(layer_name, layer.layer.name)
                child_class_name = layer.layer.__class__.__name__
                class_name = "{}({})".format(class_name, child_class_name)

        if expand_nested and isinstance(layer, functional.Functional):
            submodel_not_wrapper = model_to_dot(
                layer,
                show_shapes,
                show_dtype,
                show_layer_names,
                show_classes,
                rankdir,
                expand_nested,
                subgraph=True,
            )
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label as "ClassName|layer_name", leaving out disabled
        # parts so no blank record field is rendered (previously
        # show_layer_names=False produced "ClassName|").
        label_parts = []
        if show_classes and not is_op:
            label_parts.append(class_name)
        if show_layer_names:
            label_parts.append(layer_name)
        label = "|".join(label_parts)

        # Rebuild the label as a table including the layer's dtype.
        if show_dtype:
            label = "%s|%s" % (label, format_dtype(layer.dtype))

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = format_shape(layer.output_shape)
            except AttributeError:
                outputlabels = "?"
            if hasattr(layer, "input_shape"):
                inputlabels = format_shape(layer.input_shape)
            elif hasattr(layer, "input_shapes"):
                inputlabels = ", ".join(
                    [format_shape(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = "?"
            # Produces a vertical record: {<label>|input:<in>|output:<out>}
            label = "{%s|input:%s|output:%s}" % (
                label,
                inputlabels,
                outputlabels,
            )

        # An expanded nested model is represented by its cluster, not a node.
        if not expand_nested or not isinstance(layer, functional.Functional):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges.
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + "_ib-" + str(i)
            # Only follow inbound nodes that belong to this model's graph.
            if node_key in model._network_nodes:
                for inbound_layer in nest.flatten(node.inbound_layers):
                    inbound_layer_id = str(id(inbound_layer))
                    if not expand_nested:
                        assert dot.get_node(inbound_layer_id)
                        assert dot.get_node(layer_id)
                        add_edge(dot, inbound_layer_id, layer_id)
                    else:
                        # if inbound_layer is not Model or wrapped Model
                        if not isinstance(
                                inbound_layer, functional.Functional
                        ) and not is_wrapped_model(inbound_layer):
                            # if current layer is not Model or wrapped Model
                            if not isinstance(
                                    layer, functional.Functional
                            ) and not is_wrapped_model(layer):
                                assert dot.get_node(inbound_layer_id)
                                assert dot.get_node(layer_id)
                                add_edge(dot, inbound_layer_id, layer_id)
                            # if current layer is Model
                            elif isinstance(layer, functional.Functional):
                                add_edge(
                                    dot,
                                    inbound_layer_id,
                                    sub_n_first_node[layer.name].get_name(),
                                )
                            # if current layer is wrapped Model
                            elif is_wrapped_model(layer):
                                add_edge(dot, inbound_layer_id, layer_id)
                                name = sub_w_first_node[
                                    layer.layer.name].get_name()
                                add_edge(dot, layer_id, name)
                        # if inbound_layer is Model
                        elif isinstance(inbound_layer, functional.Functional):
                            name = sub_n_last_node[
                                inbound_layer.name].get_name()
                            if isinstance(layer, functional.Functional):
                                output_name = sub_n_first_node[
                                    layer.name].get_name()
                                add_edge(dot, name, output_name)
                            else:
                                add_edge(dot, name, layer_id)
                        # if inbound_layer is wrapped Model
                        elif is_wrapped_model(inbound_layer):
                            inbound_layer_name = inbound_layer.layer.name
                            add_edge(
                                dot,
                                sub_w_last_node[inbound_layer_name].get_name(),
                                layer_id,
                            )
    return dot
Beispiel #6
0
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
    """Convert a Keras model to dot format.

  Arguments:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch.
    subgraph: whether to return a `pydot.Cluster` instance.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`.

  Raises:
    ImportError: if graphviz or pydot are not available.
  """
    from tensorflow.python.keras.layers import wrappers
    from tensorflow.python.keras.engine import sequential
    from tensorflow.python.keras.engine import functional

    if not check_pydot():
        message = (
            'Failed to import pydot. You must `pip install pydot` '
            'and install graphviz (https://graphviz.gitlab.io/download/), ',
            'for `pydotprint` to work.')
        if 'IPython.core.magics.namespace' in sys.modules:
            # We don't raise an exception here in order to avoid crashing notebook
            # tests where graphviz is not available.
            print(message)
            return
        else:
            raise ImportError(message)

    if subgraph:
        dot = pydot.Cluster(style='dashed', graph_name=model.name)
        dot.set('label', model.name)
        dot.set('labeljust', 'l')
    else:
        dot = pydot.Dot()
        dot.set('rankdir', rankdir)
        dot.set('concentrate', True)
        dot.set('dpi', dpi)
        dot.set_node_defaults(shape='record')

    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    layers = model.layers
    if not model._is_graph_network:
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        if not model.built:
            model.build()
        layers = super(sequential.Sequential, model).layers

    # Create graph nodes.
    for i, layer in enumerate(layers):
        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__

        if isinstance(layer, wrappers.Wrapper):
            if expand_nested and isinstance(layer.layer,
                                            functional.Functional):
                submodel_wrapper = model_to_dot(layer.layer,
                                                show_shapes,
                                                show_dtype,
                                                show_layer_names,
                                                rankdir,
                                                expand_nested,
                                                subgraph=True)
                # sub_w : submodel_wrapper
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = '{}({})'.format(layer_name, layer.layer.name)
                child_class_name = layer.layer.__class__.__name__
                class_name = '{}({})'.format(class_name, child_class_name)

        if expand_nested and isinstance(layer, functional.Functional):
            submodel_not_wrapper = model_to_dot(layer,
                                                show_shapes,
                                                show_dtype,
                                                show_layer_names,
                                                rankdir,
                                                expand_nested,
                                                subgraph=True)
            # sub_n : submodel_not_wrapper
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
        else:
            label = class_name

        # Rebuild the label as a table including the layer's dtype.
        if show_dtype:

            def format_dtype(dtype):
                if dtype is None:
                    return '?'
                else:
                    return str(dtype)

            label = '%s|%s' % (label, format_dtype(layer.dtype))

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:

            def format_shape(shape):
                return str(shape).replace(str(None), 'None')

            try:
                outputlabels = format_shape(layer.output_shape)
            except AttributeError:
                outputlabels = '?'
            if hasattr(layer, 'input_shape'):
                inputlabels = format_shape(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [format_shape(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = '?'
            label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
                                                           outputlabels)

        if not expand_nested or not isinstance(layer, functional.Functional):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges.
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(i)
            if node_key in model._network_nodes:
                for inbound_layer in nest.flatten(node.inbound_layers):
                    inbound_layer_id = str(id(inbound_layer))
                    if not expand_nested:
                        assert dot.get_node(inbound_layer_id)
                        assert dot.get_node(layer_id)
                        add_edge(dot, inbound_layer_id, layer_id)
                    else:
                        # if inbound_layer is not Model or wrapped Model
                        if (not isinstance(inbound_layer,
                                           functional.Functional)
                                and not is_wrapped_model(inbound_layer)):
                            # if current layer is not Model or wrapped Model
                            if (not isinstance(layer, functional.Functional)
                                    and not is_wrapped_model(layer)):
                                assert dot.get_node(inbound_layer_id)
                                assert dot.get_node(layer_id)
                                add_edge(dot, inbound_layer_id, layer_id)
                            # if current layer is Model
                            elif isinstance(layer, functional.Functional):
                                add_edge(
                                    dot, inbound_layer_id,
                                    sub_n_first_node[layer.name].get_name())
                            # if current layer is wrapped Model
                            elif is_wrapped_model(layer):
                                add_edge(dot, inbound_layer_id, layer_id)
                                name = sub_w_first_node[
                                    layer.layer.name].get_name()
                                add_edge(dot, layer_id, name)
                        # if inbound_layer is Model
                        elif isinstance(inbound_layer, functional.Functional):
                            name = sub_n_last_node[
                                inbound_layer.name].get_name()
                            if isinstance(layer, functional.Functional):
                                output_name = sub_n_first_node[
                                    layer.name].get_name()
                                add_edge(dot, name, output_name)
                            else:
                                add_edge(dot, name, layer_id)
                        # if inbound_layer is wrapped Model
                        elif is_wrapped_model(inbound_layer):
                            inbound_layer_name = inbound_layer.layer.name
                            add_edge(
                                dot,
                                sub_w_last_node[inbound_layer_name].get_name(),
                                layer_id)
    return dot
    def __call__(self, fct, graph=None):
        """Create a pydot graph from a function.

        Parameters
        ----------
        fct : theano.compile.function_module.Function
            A compiled Theano function, a ``gof.FunctionGraph``, a variable,
            an apply node or a list/tuple of variables.
        graph: pydot.Dot
            `pydot` graph to which nodes are added. A new one is created if
            undefined.

        Returns
        -------
        pydot.Dot
            Pydot graph of `fct`.
        """
        if graph is None:
            graph = pd.Dot()

        # Reset the per-call node bookkeeping used by __node_id.
        self.__nodes = {}

        profile = None
        if isinstance(fct, Function):
            mode = fct.maker.mode
            if (not isinstance(mode, ProfileMode)
                    or fct not in mode.profile_stats):
                mode = None
            if mode:
                profile = mode.profile_stats[fct]
            else:
                profile = getattr(fct, "profile", None)
            outputs = fct.maker.fgraph.outputs
            topo = fct.maker.fgraph.toposort()
        elif isinstance(fct, gof.FunctionGraph):
            outputs = fct.outputs
            topo = fct.toposort()
        else:
            # Normalise a variable / apply node / list of variables into a
            # FunctionGraph so it can be toposorted.
            if isinstance(fct, gof.Variable):
                fct = [fct]
            elif isinstance(fct, gof.Apply):
                fct = fct.outputs
            assert isinstance(fct, (list, tuple))
            assert all(isinstance(v, gof.Variable) for v in fct)
            fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct), outputs=fct)
            outputs = fct.outputs
            topo = fct.toposort()
        outputs = list(outputs)

        # Loop over apply nodes
        for node in topo:
            apply_id = self.__node_id(node)
            apply_lbl = apply_label(node)
            nparams = {
                'name': apply_id,
                'label': apply_lbl,
                'profile': apply_profile(node, profile),
                'node_type': 'apply',
                'apply_op': apply_lbl,
                'shape': self.shapes['apply'],
            }

            # The last matching entry of apply_colors wins.
            use_color = None
            for opName, color in iteritems(self.apply_colors):
                if opName in node.op.__class__.__name__:
                    use_color = color
            if use_color:
                nparams['style'] = 'filled'
                nparams['fillcolor'] = use_color
                nparams['type'] = 'colored'

            pd_node = dict_to_pdnode(nparams)
            graph.add_node(pd_node)

            # Input positions listed in the op's view_map / destroy_map;
            # computed once per node instead of once per input.
            view_inputs = {
                i
                for ids in itervalues(getattr(node.op, 'view_map', {}))
                for i in ids
            }
            destroyed_inputs = {
                i
                for ids in itervalues(getattr(node.op, 'destroy_map', {}))
                for i in ids
            }

            # Loop over input nodes
            for idx, var in enumerate(node.inputs):
                var_id = self.__node_id(var.owner if var.owner else var)
                # The edge label must describe *this* variable; previously a
                # stale `vparams` from an earlier variable was reused here.
                dtype = type_to_str(var.type)
                if var.owner is None:
                    vparams = {
                        'name': var_id,
                        'label': var_label(var),
                        'node_type': 'input'
                    }
                    if isinstance(var, gof.Constant):
                        vparams['node_type'] = 'constant_input'
                    elif isinstance(
                            var, theano.tensor.sharedvar.TensorSharedVariable):
                        vparams['node_type'] = 'shared_input'
                    vparams['dtype'] = dtype
                    vparams['tag'] = var_tag(var)
                    vparams['style'] = 'filled'
                    vparams['fillcolor'] = self.node_colors[
                        vparams['node_type']]
                    vparams['shape'] = self.shapes['input']
                    graph.add_node(dict_to_pdnode(vparams))

                edge_params = {}
                if idx in view_inputs:
                    edge_params['color'] = self.node_colors['output']
                elif idx in destroyed_inputs:
                    edge_params['color'] = 'red'

                edge_label = dtype
                if len(node.inputs) > 1:
                    edge_label = str(idx) + ' ' + edge_label
                graph.add_edge(pd.Edge(var_id,
                                       apply_id,
                                       label=edge_label,
                                       **edge_params))

            # Loop over output nodes
            for var in node.outputs:
                var_id = self.__node_id(var)
                dtype = type_to_str(var.type)

                if var in outputs or len(var.clients) == 0:
                    vparams = {
                        'name': var_id,
                        'label': var_label(var),
                        'node_type': 'output',
                        'dtype': dtype,
                        'tag': var_tag(var),
                        'style': 'filled'
                    }
                    # Outputs without clients are drawn as "unused".
                    if len(var.clients) == 0:
                        vparams['fillcolor'] = self.node_colors['unused']
                    else:
                        vparams['fillcolor'] = self.node_colors['output']
                    vparams['shape'] = self.shapes['output']
                    graph.add_node(dict_to_pdnode(vparams))

                    graph.add_edge(pd.Edge(apply_id, var_id, label=dtype))
                elif var.name or not self.compact:
                    graph.add_edge(pd.Edge(apply_id, var_id, label=dtype))

            # Create sub-graph for OpFromGraph nodes
            if isinstance(node.op, builders.OpFromGraph):
                subgraph = pd.Cluster(apply_id)
                gf = PyDotFormatter()
                # Use different node prefix for sub-graphs
                gf.__node_prefix = apply_id
                gf(node.op.fn, subgraph)
                graph.add_subgraph(subgraph)
                pd_node.get_attributes()['subg'] = subgraph.get_name()

                def format_map(m):
                    return str([list(x) for x in m])

                # Inputs mapping
                ext_inputs = [self.__node_id(x) for x in node.inputs]
                int_inputs = [
                    gf.__node_id(x) for x in node.op.fn.maker.fgraph.inputs
                ]
                assert len(ext_inputs) == len(int_inputs)
                h = format_map(zip(ext_inputs, int_inputs))
                pd_node.get_attributes()['subg_map_inputs'] = h

                # Outputs mapping: ids of the apply nodes that consume an
                # output of this OpFromGraph node.
                ext_outputs = []
                for n in topo:
                    for i in n.inputs:
                        h = i.owner if i.owner else i
                        if h is node:
                            ext_outputs.append(self.__node_id(n))
                int_outputs = node.op.fn.maker.fgraph.outputs
                int_outputs = [gf.__node_id(x) for x in int_outputs]
                assert len(ext_outputs) == len(int_outputs)
                h = format_map(zip(int_outputs, ext_outputs))
                pd_node.get_attributes()['subg_map_outputs'] = h

        return graph
Beispiel #8
0
def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
    """Convert a Keras model to dot format.

    # Arguments
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.
        expand_nested: whether to expand nested models into clusters.
        dpi: dot DPI (only applied to the top-level `pydot.Dot`).
        subgraph: whether to return a pydot.Cluster instance.

    # Returns
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`.
    """
    from keras.layers.wrappers import Wrapper
    from keras.models import Model
    from keras.models import Sequential

    _check_pydot()
    if subgraph:
        dot = pydot.Cluster(style='dashed', graph_name=model.name)
        dot.set('label', model.name)
        dot.set('labeljust', 'l')
    else:
        dot = pydot.Dot()
        dot.set('rankdir', rankdir)
        dot.set('concentrate', True)
        dot.set('dpi', dpi)
        dot.set_node_defaults(shape='record')

    # First/last pydot nodes of expanded sub-models, keyed by sub-model name.
    # Edges into / out of an expanded model are attached to these nodes.
    sub_n_first_node = {}  # sub_n: nested model used directly as a layer
    sub_n_last_node = {}
    sub_w_first_node = {}  # sub_w: nested model inside a Wrapper layer
    sub_w_last_node = {}

    if isinstance(model, Sequential):
        if not model.built:
            model.build()
    layers = model._layers

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__

        if isinstance(layer, Wrapper):
            if expand_nested and isinstance(layer.layer, Model):
                submodel_wrapper = model_to_dot(layer.layer, show_shapes,
                                                show_layer_names, rankdir,
                                                expand_nested,
                                                subgraph=True)
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = '{}({})'.format(layer_name, layer.layer.name)
                child_class_name = layer.layer.__class__.__name__
                class_name = '{}({})'.format(class_name, child_class_name)

        if expand_nested and isinstance(layer, Model):
            submodel_not_wrapper = model_to_dot(layer, show_shapes,
                                                show_layer_names, rankdir,
                                                expand_nested,
                                                subgraph=True)
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
        else:
            label = class_name

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = str(layer.output_shape)
            except AttributeError:
                outputlabels = 'multiple'
            if hasattr(layer, 'input_shape'):
                inputlabels = str(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [str(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = 'multiple'
            label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                           inputlabels,
                                                           outputlabels)

        # Expanded nested models are represented by their cluster instead.
        if not expand_nested or not isinstance(layer, Model):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Keys ('<layer name>_ib-<index>') of the inbound nodes that belong to
    # this model.  The filter was previously disabled (`if True or ...`);
    # when the model exposes no (or an empty) `_network_nodes`, every
    # inbound node is still accepted as before.
    network_nodes = getattr(model, '_network_nodes', None)

    # Connect nodes with edges.
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(i)
            if network_nodes and node_key not in network_nodes:
                # This call of `layer` is not part of this model (e.g. the
                # layer is shared with another model); its inbound layers
                # may have no node in this graph.
                continue
            for inbound_layer in node.inbound_layers:
                inbound_layer_id = str(id(inbound_layer))
                if not expand_nested:
                    assert dot.get_node(inbound_layer_id)
                    assert dot.get_node(layer_id)
                    dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
                    continue

                # if inbound_layer is not Model or wrapped Model
                if not is_model(inbound_layer) and (
                        not is_wrapped_model(inbound_layer)):
                    # if current layer is not Model or wrapped Model
                    if not is_model(layer) and (
                            not is_wrapped_model(layer)):
                        assert dot.get_node(layer_id)
                        dot.add_edge(pydot.Edge(inbound_layer_id,
                                                layer_id))
                    # if current layer is Model
                    elif is_model(layer):
                        add_edge(dot, inbound_layer_id,
                                 sub_n_first_node[layer.name].get_name())
                    # if current layer is wrapped Model
                    elif is_wrapped_model(layer):
                        dot.add_edge(pydot.Edge(inbound_layer_id,
                                                layer_id))
                        name = sub_w_first_node[layer.layer.name].get_name()
                        dot.add_edge(pydot.Edge(layer_id, name))
                # if inbound_layer is Model
                elif is_model(inbound_layer):
                    name = sub_n_last_node[inbound_layer.name].get_name()
                    if is_model(layer):
                        output_name = sub_n_first_node[layer.name].get_name()
                        add_edge(dot, name, output_name)
                    else:
                        add_edge(dot, name, layer_id)
                # if inbound_layer is wrapped Model
                elif is_wrapped_model(inbound_layer):
                    inbound_layer_name = inbound_layer.layer.name
                    add_edge(dot,
                             sub_w_last_node[inbound_layer_name].get_name(),
                             layer_id)
    return dot
Beispiel #9
0
def pydotprint(
    fct,
    outfile=None,
    compact=True,
    format="png",
    with_ids=False,
    high_contrast=True,
    cond_highlight=None,
    colorCodes=None,
    max_label_size=70,
    scan_graphs=False,
    var_with_name_simple=False,
    print_output_file=True,
    return_image=False,
):
    """Print to a file the graph of a compiled aesara function's ops. Supports
    all pydot output formats, including png and svg.

    :param fct: a compiled Aesara function, a FunctionGraph, a Variable, an
                Apply or a list of Variable.
    :param outfile: the output file where to put the graph. Defaults to
                ``aesara.pydotprint.<device>.<format>`` inside
                ``config.compiledir``; ``.<format>`` is appended when the
                name does not already end with it.
    :param compact: if True, intermediate variables that don't have a name
                are not drawn as separate boxes.
    :param format: the file format of the output.
    :param with_ids: Print the toposort index of the node in the node name.
    :param high_contrast: if true, the color that describes the respective
            node is filled with its corresponding color, instead of coloring
            the border
    :param colorCodes: dictionary with names of ops as keys and colors as
            values. Defaults to ``default_colorCodes``.
    :param cond_highlight: Highlights a lazy if by surrounding each of the 3
                possible categories of ops with a border. The categories
                are: ops that are on the left branch, ops that are on the
                right branch, ops that are on both branches.
                Give either the name of the IfElse op or the IfElse Apply
                node itself.
    :param max_label_size: labels longer than this are truncated and end
                with "...".
    :param scan_graphs: if true it will plot the inner graph of each scan op
                in files with the same name as the name given for the main
                file to which the name of the scan op is concatenated and
                the index in the toposort of the scan.
                This index can be printed with the option with_ids.
    :param var_with_name_simple: If true and a variable has a name,
                we will print only the variable name.
                Otherwise, we concatenate the type to the var name.
    :param print_output_file: If True, print the path of the written file.
    :param return_image: If True, it will create the image and return it
        instead of writing a file.
        Useful to display the image in ipython notebook.

        .. code-block:: python

            import aesara
            v = aesara.tensor.vector()
            from IPython.display import SVG
            SVG(aesara.printing.pydotprint(v*2, return_image=True,
                                           format='svg'))

    :returns: the image created by pydot's ``Dot.create`` when
        ``return_image`` is True, otherwise None.

    In the graph, ellipses are Apply Nodes (the execution of an op)
    and boxes are variables.  If variables have names they are used as
    text (if multiple vars have the same name, they will be merged in
    the graph).  Otherwise, if the variable is constant, we print its
    value and finally we print the type + a unique number to prevent
    multiple vars from being merged.  We print the op of the apply in
    the Apply box with a number that represents the toposort order of
    application of those Apply.  If an Apply has more than 1 input, we
    label each edge between an input and the Apply node with the
    input's index.

    Variable color code::
        - Cyan boxes are SharedVariables (inputs and/or outputs of the graph),
        - Green boxes are inputs variables to the graph,
        - Blue boxes are outputs variables of the graph,
        - Grey boxes are variables that are not outputs and are not used,

    Default apply node code::
        - Red ellipses are transfers from/to the gpu
        - Yellow are scan node
        - Brown are shape node
        - Magenta are IfElse node
        - Dark pink are elemwise node
        - Purple are subtensor
        - Orange are alloc node

    For edges, they are black by default. If a node returns a view
    of an input, we put the corresponding input edge in blue. If it
    returns a destroyed input, we put the corresponding edge in red.

    .. note::

        Since October 20th, 2014, this prints the inner function of all
        scans separately after the top level debugprint output.

    """
    from aesara.scan.op import Scan

    if colorCodes is None:
        colorCodes = default_colorCodes

    if outfile is None:
        outfile = os.path.join(
            config.compiledir,
            "aesara.pydotprint." + config.device + "." + format)

    if isinstance(fct, Function):
        profile = getattr(fct, "profile", None)
        fgraph = fct.maker.fgraph
        outputs = fgraph.outputs
        topo = fgraph.toposort()
    elif isinstance(fct, FunctionGraph):
        profile = None
        outputs = fct.outputs
        topo = fct.toposort()
        fgraph = fct
    else:
        if isinstance(fct, Variable):
            fct = [fct]
        elif isinstance(fct, Apply):
            fct = fct.outputs
        assert isinstance(fct, (list, tuple))
        assert all(isinstance(v, Variable) for v in fct)
        fct = FunctionGraph(inputs=list(graph_inputs(fct)), outputs=fct)
        profile = None
        outputs = fct.outputs
        topo = fct.toposort()
        fgraph = fct
    if not pydot_imported:
        raise RuntimeError(
            "Failed to import pydot. You must install graphviz"
            " and either pydot or pydot-ng for "
            "`pydotprint` to work.",
            pydot_imported_msg,
        )

    # Toposort position of every apply node (O(1) instead of topo.index).
    topo_index = {node: i for i, node in enumerate(topo)}

    g = pd.Dot()

    if cond_highlight is not None:
        c1 = pd.Cluster("Left")
        c2 = pd.Cluster("Right")
        c3 = pd.Cluster("Middle")
        cond = None
        for node in topo:
            if node.op.__class__.__name__ == "IfElse" and (
                    node is cond_highlight
                    or node.op.name == cond_highlight):
                cond = node
        if cond is None:
            _logger.warning("pydotprint: cond_highlight is set but there is no"
                            " IfElse node in the graph")
            cond_highlight = None

    if cond_highlight is not None:

        def collect_owners(var):
            """Return the set of apply nodes `var` transitively depends on."""
            owners = set()
            stack = [var]
            while stack:
                current = stack.pop()
                owner = current.owner
                # Each apply node is visited once, so shared sub-graphs do
                # not cause repeated (exponential) traversal.
                if owner is not None and owner not in owners:
                    owners.add(owner)
                    stack.extend(owner.inputs)
            return owners

        left = collect_owners(cond.inputs[1])
        right = collect_owners(cond.inputs[2])
        middle = left & right
        left -= middle
        right -= middle

    var_str = {}
    var_id = {}
    all_strings = set()

    def var_name(var):
        """Return (label, graph id) for `var`, caching the result."""
        if var in var_str:
            return var_str[var], var_id[var]

        if var.name is not None:
            if var_with_name_simple:
                varstr = var.name
            else:
                varstr = "name=" + var.name + " " + str(var.type)
        elif isinstance(var, Constant):
            dstr = "val=" + str(np.asarray(var.data))
            # Keep only the first line of multi-line array reprs.
            if "\n" in dstr:
                dstr = dstr[:dstr.index("\n")]
            varstr = f"{dstr} {var.type}"
        elif var in input_update and input_update[var].name is not None:
            varstr = input_update[var].name
            if not var_with_name_simple:
                varstr += " " + str(var.type)
        else:
            # a var id is needed as otherwise var with the same type will be
            # merged in the graph.
            varstr = str(var.type)
        if len(varstr) > max_label_size:
            varstr = varstr[:max_label_size - 3] + "..."
        var_str[var] = varstr
        var_id[var] = str(id(var))

        all_strings.add(varstr)

        return varstr, var_id[var]

    apply_name_cache = {}
    apply_name_id = {}

    def apply_name(node):
        """Return a unique (label, graph id) for apply `node`, cached."""
        if node in apply_name_cache:
            return apply_name_cache[node], apply_name_id[node]
        prof_str = ""
        if profile:
            time = profile.apply_time.get((fgraph, node), 0)
            # second, %fct time in profiler
            if profile.fct_callcount == 0 or profile.fct_call_time == 0:
                pf = 0
            else:
                pf = time * 100 / profile.fct_call_time
            prof_str = f"   ({time:.3f}s,{pf:.3f}%)"
        applystr = str(node.op).replace(":", "_")
        applystr += prof_str
        if (applystr in all_strings) or with_ids:
            idx = " id=" + str(topo_index[node])
            if len(applystr) + len(idx) > max_label_size:
                applystr = applystr[:max_label_size - 3 -
                                    len(idx)] + idx + "..."
            else:
                applystr = applystr + idx
        elif len(applystr) > max_label_size:
            applystr = applystr[:max_label_size - 3] + "..."
            # Disambiguate truncated labels that collide with existing ones.
            idx = 1
            while applystr in all_strings:
                idx += 1
                suffix = " id=" + str(idx)
                applystr = applystr[:max_label_size - 3 -
                                    len(suffix)] + "..." + suffix

        all_strings.add(applystr)
        apply_name_cache[node] = applystr
        apply_name_id[node] = str(id(node))

        return applystr, apply_name_id[node]

    # Update the inputs that have an update function.
    # Maps each update output variable to the fgraph input it updates.
    input_update = {}
    # `outputs` may be the fgraph's own list; copy it before popping the
    # trailing update outputs off it.
    outputs = list(outputs)
    if isinstance(fct, Function):
        function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
        for i, fg_ii in reversed(list(function_inputs)):
            if i.update is not None:
                k = outputs.pop()
                # Use the fgaph.inputs as it isn't the same as maker.inputs
                input_update[k] = fg_ii

    apply_shape = "ellipse"
    var_shape = "box"
    for node in topo:
        astr, aid = apply_name(node)

        # The last matching entry of colorCodes wins.
        use_color = None
        for opName, color in colorCodes.items():
            if opName in node.op.__class__.__name__:
                use_color = color

        if use_color is None:
            nw_node = pd.Node(aid, label=astr, shape=apply_shape)
        elif high_contrast:
            nw_node = pd.Node(aid,
                              label=astr,
                              style="filled",
                              fillcolor=use_color,
                              shape=apply_shape)
        else:
            nw_node = pd.Node(aid,
                              label=astr,
                              color=use_color,
                              shape=apply_shape)
        g.add_node(nw_node)
        if cond_highlight:
            if node in middle:
                c3.add_node(nw_node)
            elif node in left:
                c1.add_node(nw_node)
            elif node in right:
                c2.add_node(nw_node)

        for idx, var in enumerate(node.inputs):
            varstr, varid = var_name(var)
            label = ""
            if len(node.inputs) > 1:
                label = str(idx)
            param = {}
            if label:
                param["label"] = label
            if hasattr(node.op, "view_map") and idx in reduce(
                    list.__add__, node.op.view_map.values(), []):
                param["color"] = colorCodes["Output"]
            elif hasattr(node.op, "destroy_map") and idx in reduce(
                    list.__add__, node.op.destroy_map.values(), []):
                param["color"] = "red"
            if var.owner is None:
                color = "green"
                if isinstance(var, SharedVariable):
                    # Input are green, output blue
                    # Mixing blue and green give cyan! (input and output var)
                    color = "cyan"
                if high_contrast:
                    g.add_node(
                        pd.Node(
                            varid,
                            style="filled",
                            fillcolor=color,
                            label=varstr,
                            shape=var_shape,
                        ))
                else:
                    g.add_node(
                        pd.Node(varid,
                                color=color,
                                label=varstr,
                                shape=var_shape))
                g.add_edge(pd.Edge(varid, aid, **param))
            elif var.name or not compact or var in outputs:
                g.add_edge(pd.Edge(varid, aid, **param))
            else:
                # no name, so we don't make a var ellipse
                if label:
                    label += " "
                label += str(var.type)
                if len(label) > max_label_size:
                    label = label[:max_label_size - 3] + "..."
                param["label"] = label
                g.add_edge(pd.Edge(apply_name(var.owner)[1], aid, **param))

        for idx, var in enumerate(node.outputs):
            varstr, varid = var_name(var)
            out = var in outputs
            label = ""
            if len(node.outputs) > 1:
                label = str(idx)
            if len(label) > max_label_size:
                label = label[:max_label_size - 3] + "..."
            param = {}
            if label:
                param["label"] = label
            if out or var in input_update:
                g.add_edge(pd.Edge(aid, varid, **param))
                if high_contrast:
                    g.add_node(
                        pd.Node(
                            varid,
                            style="filled",
                            label=varstr,
                            fillcolor=colorCodes["Output"],
                            shape=var_shape,
                        ))
                else:
                    g.add_node(
                        pd.Node(
                            varid,
                            color=colorCodes["Output"],
                            label=varstr,
                            shape=var_shape,
                        ))
            elif len(fgraph.clients[var]) == 0:
                g.add_edge(pd.Edge(aid, varid, **param))
                # grey mean that output var isn't used
                if high_contrast:
                    g.add_node(
                        pd.Node(
                            varid,
                            style="filled",
                            label=varstr,
                            fillcolor="grey",
                            shape=var_shape,
                        ))
                else:
                    g.add_node(
                        pd.Node(varid,
                                label=varstr,
                                color="grey",
                                shape=var_shape))
            elif var.name or not compact:
                if compact:
                    if label:
                        label += " "
                    label += str(var.type)
                    if len(label) > max_label_size:
                        label = label[:max_label_size - 3] + "..."
                    param["label"] = label
                g.add_edge(pd.Edge(aid, varid, **param))
                g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
            # Otherwise no edge is added here: it was already added when
            # the variable was visited as an input of its consumer.

    # The var that represent updates, must be linked to the input var.
    for sha, up in input_update.items():
        _, shaid = var_name(sha)
        _, upid = var_name(up)
        g.add_edge(
            pd.Edge(shaid, upid, label="UPDATE", color=colorCodes["Output"]))

    if cond_highlight:
        g.add_subgraph(c1)
        g.add_subgraph(c2)
        g.add_subgraph(c3)

    if not outfile.endswith("." + format):
        outfile += "." + format

    if scan_graphs:
        scan_ops = [(idx, x) for idx, x in enumerate(topo)
                    if isinstance(x.op, Scan)]
        path, fn = os.path.split(outfile)
        # Safe way of doing things .. a file name may contain multiple .
        basename = ".".join(fn.split(".")[:-1])
        ext = fn[len(basename):]

        for idx, scan_op in scan_ops:
            # The op may lack a name, or have it set to None.
            op_name = getattr(scan_op.op, "name", None)
            if op_name is not None:
                new_name = basename + "_" + op_name + "_" + str(idx)
            else:
                new_name = basename + "_" + str(idx)
            new_name = os.path.join(path, new_name + ext)
            if hasattr(scan_op.op, "fn"):
                to_print = scan_op.op.fn
            else:
                to_print = scan_op.op.outputs
            pydotprint(
                to_print,
                new_name,
                compact,
                format,
                with_ids,
                high_contrast,
                cond_highlight,
                colorCodes,
                max_label_size,
                scan_graphs,
                var_with_name_simple=var_with_name_simple,
                print_output_file=print_output_file,
            )

    if return_image:
        return g.create(prog="dot", format=format)
    else:
        try:
            g.write(outfile, prog="dot", format=format)
        except pd.InvocationException:
            # based on https://github.com/Theano/Theano/issues/2988
            version = getattr(pd, "__version__", "")
            parts = version.split(".") if version else []
            # Only compare purely numeric versions; anything else (e.g.
            # "1.4.2.dev0") must not mask the original error.
            if (parts and all(p.isdigit() for p in parts)
                    and [int(p) for p in parts] < [1, 0, 28]):
                raise Exception("Old version of pydot detected, which can "
                                "cause issues with pydot printing. Try "
                                "upgrading pydot version to a newer one")
            raise

        if print_output_file:
            print("The output file is available at", outfile)
Beispiel #10
0
def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 style=0,
                 color=True,
                 subgraph=False):
    """Convert a Keras model to dot format.

    Arguments:
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.
        expand_nested: whether to expand nested models into clusters.
        dpi: Dots per inch.
        style: value 0 or 1.
            0: edges are labelled with the output shape of their source
               layer; with `show_shapes`, input layers show their input
               shape and layers without outbound nodes show their output
               shape.
            1: edges are unlabelled; with `show_shapes`, every node shows
               an input/output shape table.
            Nested models expanded via `expand_nested` use the same style.
        color: whether to fill each node with a colour chosen from keywords
            in the layer's class name.
        subgraph: whether to return a `pydot.Cluster` instance.

    Returns:
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`. Returns None (after printing a message) when pydot
        is unavailable and the code runs under IPython.

    Raises:
        ImportError: if graphviz or pydot are not available.
    """
    from tensorflow.python.keras.layers import wrappers
    from tensorflow.python.keras.engine import sequential
    from tensorflow.python.keras.engine import network

    if not check_pydot():
        if 'IPython.core.magics.namespace' in sys.modules:
            # We don't raise an exception here in order to avoid crashing notebook
            # tests where graphviz is not available.
            print('Failed to import pydot. You must install pydot'
                  ' and graphviz for `pydotprint` to work.')
            return
        else:
            raise ImportError('Failed to import pydot. You must install pydot'
                              ' and graphviz for `pydotprint` to work.')

    if subgraph:
        dot = pydot.Cluster(style='dashed', graph_name=model.name)
        dot.set('label', model.name)
        dot.set('labeljust', 'l')
    else:
        dot = pydot.Dot()
        dot.set('rankdir', rankdir)
        dot.set('concentrate', True)
        dot.set('dpi', dpi)
        dot.set_node_defaults(shape='record')

    # First/last pydot node of every expanded nested model (sub_n) and of
    # every expanded wrapped model (sub_w), keyed by the model's name.
    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    if not model._is_graph_network:
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        if not model.built:
            model.build()
    layers = model._layers

    # (class-name keywords, fill colour); the first matching entry wins.
    node_colors = (
        (('input',), 'deeppink'),
        (('conv',), 'cyan'),
        (('pool',), 'chartreuse'),
        (('normalization',), 'dodgerblue1'),
        (('activation',), 'pink'),
        (('dropout',), 'darkorange'),
        (('dense',), 'darkorchid1'),
        (('padding',), 'beige'),
        (('concatenate',), 'tomato'),
        (('rnn', 'lstm', 'gru', 'bidirectional'), 'yellow1'),
    )

    def format_shape(shape):
        """Render a shape, showing unknown (None) dimensions as '?'."""
        return str(shape).replace(str(None), '?')

    def shape_label(shape_layer):
        """Return `shape_layer`'s formatted output shape, or '?'."""
        try:
            return format_shape(shape_layer.output_shape)
        except Exception:  # shape unavailable: fall back to a placeholder
            return '?'

    def connect(src_id, dst_id, shape_layer):
        """Add an edge; in style 0 it is labelled with `shape_layer`'s shape."""
        if style == 0:
            add_edge(dot, src_id, dst_id, shape_label(shape_layer))
        elif style == 1:
            add_edge(dot, src_id, dst_id)

    def expand(submodel):
        """Render `submodel` as a cluster using this call's settings."""
        return model_to_dot(submodel,
                            show_shapes,
                            show_layer_names,
                            rankdir,
                            expand_nested,
                            dpi=dpi,
                            style=style,
                            color=color,
                            subgraph=True)

    def fill_color(class_name_lower):
        """Pick a node fill colour from keywords in the class name."""
        for keywords, fill in node_colors:
            if any(keyword in class_name_lower for keyword in keywords):
                return fill
        return 'gold'

    def detailed_label(default, head, class_name_lower, config):
        """Return a node label enriched with key config values.

        `default` is returned when no keyword matches; otherwise the label is
        `head` followed by the relevant config values. The checks run in
        sequence, so a later matching keyword overrides an earlier one.
        """
        label = default
        if 'conv' in class_name_lower:
            label = '{},{}|kernel:{}  strides:{}'.format(
                head, config['padding'], config['kernel_size'],
                config['strides'])
        if 'pool' in class_name_lower and class_name_lower[:6] != 'global':
            label = '{},{}|kernel:{}  strides:{}'.format(
                head, config['padding'], config['pool_size'],
                config['strides'])
        if 'activation' in class_name_lower:
            label = '{}|{}'.format(head, config['activation'])
        if 'dropout' in class_name_lower:
            label = '{}|{}'.format(head, config['rate'])
        if 'dense' in class_name_lower:
            label = '{}|{}'.format(head, config['activation'])
        return label

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        layer_name = layer.name
        class_name = layer.__class__.__name__
        # Keyword matching uses the layer's own class name, computed before
        # the "Wrapper(Child)" form below is built.
        class_name_lower = class_name.lower()
        try:
            config = layer.get_config()
        except Exception:  # best effort: label falls back to the plain name
            config = None

        # Append a wrapped layer's label to node's label, if it exists.
        if isinstance(layer, wrappers.Wrapper):
            if expand_nested and isinstance(layer.layer, network.Network):
                submodel_wrapper = expand(layer.layer)
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = '{}({})'.format(layer_name, layer.layer.name)
                child_class_name = layer.layer.__class__.__name__
                class_name = '{}({})'.format(class_name, child_class_name)

        if expand_nested and isinstance(layer, network.Network):
            submodel_not_wrapper = expand(layer)
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
            head = '{}:{}'.format(layer_name, class_name)
        else:
            label = head = class_name
        if 'input' not in class_name_lower and config is not None:
            label = detailed_label(label, head, class_name_lower, config)

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = format_shape(layer.output_shape)
            except AttributeError:
                outputlabels = '?'
            if hasattr(layer, 'input_shape'):
                inputlabels = format_shape(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [format_shape(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = '?'

            if style == 0:
                if 'input' in class_name_lower:
                    label = '{%s}|{input:}|{%s}' % (label, inputlabels)
                else:
                    owners = [
                        owner for inbound_node in layer._inbound_nodes
                        for owner in nest.flatten(inbound_node.outbound_layer)
                    ]
                    # Wrap the label exactly once, even for layers with
                    # several inbound nodes; wrapping once per node would
                    # nest the record fields and corrupt the label.
                    if owners:
                        if not owners[-1].outbound_nodes:
                            # No outbound nodes: show the output shape.
                            label = '{%s}|{output:}|{%s}' % (label,
                                                             outputlabels)
                        else:
                            label = '{%s}' % label
            elif style == 1:
                label = '{%s}|{input:|output:}|{{%s}|{%s}}' % (
                    label, inputlabels, outputlabels)

        # Expanded nested models are drawn as clusters, not as a node.
        if not expand_nested or not isinstance(layer, network.Network):
            if color:
                node = pydot.Node(layer_id,
                                  label=label,
                                  fillcolor=fill_color(class_name_lower),
                                  style="filled")
            else:
                node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges.
    for layer in layers:
        layer_id = str(id(layer))
        for node_index, inbound_node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(node_index)
            if node_key not in model._network_nodes:
                continue
            for inbound_layer in nest.flatten(inbound_node.inbound_layers):
                inbound_layer_id = str(id(inbound_layer))
                if not expand_nested:
                    assert dot.get_node(inbound_layer_id)
                    assert dot.get_node(layer_id)
                    connect(inbound_layer_id, layer_id, inbound_layer)
                # inbound_layer is Model: start from its last node.
                elif isinstance(inbound_layer, network.Network):
                    source = sub_n_last_node[inbound_layer.name].get_name()
                    if isinstance(layer, network.Network):
                        target = sub_n_first_node[layer.name].get_name()
                    else:
                        target = layer_id
                    # The edge carries the nested model's output.
                    connect(source, target, inbound_layer)
                # inbound_layer is wrapped Model: start from its last node.
                elif is_wrapped_model(inbound_layer):
                    source = sub_w_last_node[
                        inbound_layer.layer.name].get_name()
                    connect(source, layer_id, inbound_layer)
                # current layer is Model: end at its first node.
                elif isinstance(layer, network.Network):
                    connect(inbound_layer_id,
                            sub_n_first_node[layer.name].get_name(),
                            inbound_layer)
                # current layer is wrapped Model: route through the wrapper.
                elif is_wrapped_model(layer):
                    connect(inbound_layer_id, layer_id, inbound_layer)
                    connect(layer_id,
                            sub_w_first_node[layer.layer.name].get_name(),
                            layer)
                # neither side is a Model or wrapped Model.
                else:
                    assert dot.get_node(inbound_layer_id)
                    assert dot.get_node(layer_id)
                    connect(inbound_layer_id, layer_id, inbound_layer)

    return dot
Beispiel #11
0
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False,
                 layer_range=None):
    """Convert a Keras model to dot format.

    Args:
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot:
            'TB' creates a vertical plot;
            'LR' creates a horizontal plot.
        expand_nested: whether to expand nested models into clusters.
        dpi: Dots per inch.
        subgraph: whether to return a `pydot.Cluster` instance.
        layer_range: input of `list` containing two `str` items, which is the
            starting layer name and ending layer name (both inclusive)
            indicating the range of layers for which the `pydot.Dot` will be
            generated. It also accepts regex patterns instead of exact name.
            In such case, start predicate will be the first element it
            matches to `layer_range[0]` and the end predicate will be the
            last element it matches to `layer_range[1]`. By default `None`
            which considers all layers of model. Note that you must pass
            range such that the resultant subgraph must be complete.

    Returns:
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`. Returns None (after printing a message) when pydot
        is unavailable and the code runs under IPython.

    Raises:
        ImportError: if graphviz or pydot are not available.
        ValueError: if `layer_range` is malformed or out of bounds.
    """
    from keras.layers import wrappers
    from keras.engine import sequential
    from keras.engine import functional

    if not check_pydot():
        # Adjacent literals with no separating comma: one message string,
        # not a tuple, so the printed text and the ImportError read cleanly.
        message = (
            'You must install pydot (`pip install pydot`) '
            'and install graphviz '
            '(see instructions at https://graphviz.gitlab.io/download/) '
            'for plot_model/model_to_dot to work.')
        if 'IPython.core.magics.namespace' in sys.modules:
            # We don't raise an exception here in order to avoid crashing notebook
            # tests where graphviz is not available.
            print(message)
            return
        raise ImportError(message)

    if subgraph:
        dot = pydot.Cluster(style='dashed', graph_name=model.name)
        dot.set('label', model.name)
        dot.set('labeljust', 'l')
    else:
        dot = pydot.Dot()
        dot.set('rankdir', rankdir)
        dot.set('concentrate', True)
        dot.set('dpi', dpi)
        dot.set_node_defaults(shape='record')

    if layer_range:
        if len(layer_range) != 2:
            raise ValueError('layer_range must be of shape (2,)')
        if (not isinstance(layer_range[0], str)
                or not isinstance(layer_range[1], str)):
            raise ValueError('layer_range should contain string type only')
        # From here on layer_range holds [first_index, last_index].
        layer_range = get_layer_index_bound_by_layer_name(model, layer_range)
        if layer_range[0] < 0 or layer_range[1] > len(model.layers):
            raise ValueError('Both values in layer_range should be in '
                             'range (%d, %d)' % (0, len(model.layers)))

    # First/last pydot node of every expanded nested model (sub_n) and of
    # every expanded wrapped model (sub_w), keyed by the model's name.
    sub_n_first_node = {}
    sub_n_last_node = {}
    sub_w_first_node = {}
    sub_w_last_node = {}

    layers = model.layers
    if not model._is_graph_network:
        node = pydot.Node(str(id(model)), label=model.name)
        dot.add_node(node)
        return dot
    elif isinstance(model, sequential.Sequential):
        if not model.built:
            model.build()
        layers = super(sequential.Sequential, model).layers

    def format_dtype(dtype):
        """Render a dtype, using '?' when it is unknown."""
        return '?' if dtype is None else str(dtype)

    def format_shape(shape):
        """Render a shape; unknown dimensions print as 'None'."""
        return str(shape)

    # Create graph nodes.
    for i, layer in enumerate(layers):
        if (layer_range) and (i < layer_range[0] or i > layer_range[1]):
            continue

        layer_id = str(id(layer))

        # Append a wrapped layer's label to node's label, if it exists.
        layer_name = layer.name
        class_name = layer.__class__.__name__

        if isinstance(layer, wrappers.Wrapper):
            if expand_nested and isinstance(layer.layer,
                                            functional.Functional):
                submodel_wrapper = model_to_dot(layer.layer,
                                                show_shapes,
                                                show_dtype,
                                                show_layer_names,
                                                rankdir,
                                                expand_nested,
                                                subgraph=True)
                sub_w_nodes = submodel_wrapper.get_nodes()
                sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
                sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
                dot.add_subgraph(submodel_wrapper)
            else:
                layer_name = '{}({})'.format(layer_name, layer.layer.name)
                child_class_name = layer.layer.__class__.__name__
                class_name = '{}({})'.format(class_name, child_class_name)

        if expand_nested and isinstance(layer, functional.Functional):
            submodel_not_wrapper = model_to_dot(layer,
                                                show_shapes,
                                                show_dtype,
                                                show_layer_names,
                                                rankdir,
                                                expand_nested,
                                                subgraph=True)
            sub_n_nodes = submodel_not_wrapper.get_nodes()
            sub_n_first_node[layer.name] = sub_n_nodes[0]
            sub_n_last_node[layer.name] = sub_n_nodes[-1]
            dot.add_subgraph(submodel_not_wrapper)

        # Create node's label.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
        else:
            label = class_name

        # Rebuild the label as a table including the layer's dtype.
        if show_dtype:
            label = '%s|%s' % (label, format_dtype(layer.dtype))

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = format_shape(layer.output_shape)
            except AttributeError:
                outputlabels = '?'
            if hasattr(layer, 'input_shape'):
                inputlabels = format_shape(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [format_shape(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = '?'
            label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
                                                           outputlabels)

        # Expanded nested models are drawn as clusters, not as a node.
        if not expand_nested or not isinstance(layer, functional.Functional):
            node = pydot.Node(layer_id, label=label)
            dot.add_node(node)

    # Connect nodes with edges. The first layer in range is skipped because
    # its inbound layers lie outside the range and have no nodes.
    for i, layer in enumerate(layers):
        if (layer_range) and (i <= layer_range[0] or i > layer_range[1]):
            continue
        layer_id = str(id(layer))
        for node_index, inbound_node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(node_index)
            if node_key not in model._network_nodes:
                continue
            for inbound_layer in tf.nest.flatten(inbound_node.inbound_layers):
                inbound_layer_id = str(id(inbound_layer))
                if not expand_nested:
                    assert dot.get_node(inbound_layer_id)
                    assert dot.get_node(layer_id)
                    add_edge(dot, inbound_layer_id, layer_id)
                # inbound_layer is Model: start from its last node.
                elif isinstance(inbound_layer, functional.Functional):
                    name = sub_n_last_node[inbound_layer.name].get_name()
                    if isinstance(layer, functional.Functional):
                        output_name = sub_n_first_node[layer.name].get_name()
                        add_edge(dot, name, output_name)
                    else:
                        add_edge(dot, name, layer_id)
                # inbound_layer is wrapped Model: start from its last node.
                elif is_wrapped_model(inbound_layer):
                    inbound_layer_name = inbound_layer.layer.name
                    add_edge(dot,
                             sub_w_last_node[inbound_layer_name].get_name(),
                             layer_id)
                # current layer is Model: end at its first node.
                elif isinstance(layer, functional.Functional):
                    add_edge(dot, inbound_layer_id,
                             sub_n_first_node[layer.name].get_name())
                # current layer is wrapped Model: route through the wrapper.
                elif is_wrapped_model(layer):
                    add_edge(dot, inbound_layer_id, layer_id)
                    name = sub_w_first_node[layer.layer.name].get_name()
                    add_edge(dot, layer_id, name)
                # neither side is a Model or wrapped Model.
                else:
                    assert dot.get_node(inbound_layer_id)
                    assert dot.get_node(layer_id)
                    add_edge(dot, inbound_layer_id, layer_id)
    return dot