示例#1
0
class MyiaGraphPrinter(GraphPrinter):
    """
    Utility to generate a graphical representation for a graph.

    Attributes:
        duplicate_constants: Whether to create a separate node for
            every instance of the same constant.
        duplicate_free_variables: Whether to create a separate node
            to represent the use of a free variable, or point directly
            to that node in a different graph.
        function_in_node: Whether to print, when possible, the name
            of a node's operation directly in the node's label instead
            of creating a node for the operation and drawing an edge
            to it.
        follow_references: Whether to also print graphs that are
            called by this graph.
        beautify: Whether the entry points are imported as cosmetic
            clones (see `_import_graph`) and passed through
            `cosmetic_transformer` before being drawn.

    """
    def __init__(self,
                 entry_points,
                 *,
                 duplicate_constants=False,
                 duplicate_free_variables=False,
                 function_in_node=False,
                 follow_references=False,
                 tooltip_gen=None,
                 class_gen=None,
                 extra_style=None,
                 beautify=True):
        """Initialize a MyiaGraphPrinter.

        Arguments:
            entry_points: Iterable of graphs to print. It is iterated
                exactly once.
            duplicate_constants: See the class docstring.
            duplicate_free_variables: See the class docstring.
            function_in_node: See the class docstring; also forwarded
                to the `NodeLabeler`.
            follow_references: See the class docstring.
            tooltip_gen: Forwarded to `GraphPrinter`.
            class_gen: Optional hook used to customize node classes,
                normalized through `_make_class_gen`.
            extra_style: Forwarded to `GraphPrinter`.
            beautify: See the class docstring.
        """
        super().__init__({'layout': {
            'name': 'dagre',
            'rankDir': 'TB'
        }},
                         tooltip_gen=tooltip_gen,
                         extra_style=extra_style)
        # Graphs left to process, and graphs drawn with the 'focus'
        # class (the entry points, or their cosmetic clones).
        if beautify:
            self.graphs = set()
            self.focus = set()
            for g in entry_points:
                self._import_graph(g)
        else:
            self.graphs = set(entry_points)
            self.focus = set(self.graphs)

        self.beautify = beautify
        self.duplicate_constants = duplicate_constants
        self.duplicate_free_variables = duplicate_free_variables
        self.function_in_node = function_in_node
        self.follow_references = follow_references
        self.labeler = NodeLabeler(function_in_node=function_in_node,
                                   relation_symbols=short_relation_symbols)
        self._class_gen = _make_class_gen(class_gen)
        # Nodes (and graphs) processed
        self.processed = set()
        # Nodes left to process
        self.pool = set()
        # Nodes that are to be colored as return nodes
        self.returns = set()
        # IDs for duplicated constants
        self.currid = 0

    def _import_graph(self, graph):
        """Queue a cosmetic clone of `graph` and of all its ancestors.

        The chain of parents reported by a weak graph manager is walked
        until the parent map yields a falsy value. The whole chain is
        cloned in a single `GraphCloner` call with relation 'cosmetic'.
        Every clone is queued for processing, and the clone of `graph`
        itself is marked as a focus graph.
        """
        mng = manage(graph, weak=True)
        parents = mng.parents
        # graph, its parent, its parent's parent, ...
        graphs = set()
        g = graph
        while g:
            graphs.add(g)
            g = parents[g]
        clone = GraphCloner(*graphs, total=True, relation='cosmetic')
        self.graphs |= {clone[g] for g in graphs}
        self.focus.add(clone[graph])

    def name(self, x):
        """Return the name of a node."""
        return self.labeler.name(x, force=True)

    def label(self, node, fn_label=None):
        """Return the label to give to a node."""
        return self.labeler.label(node, None, fn_label=fn_label)

    def const_fn(self, node):
        """
        Return name of function, if constant.

        Given an `Apply` node of a constant function, return the
        name of that function, otherwise return None.
        """
        return self.labeler.const_fn(node)

    def add_graph(self, g):
        """Create a node for a graph.

        Does nothing if `g` is already recorded as processed. When
        `beautify` is set, `g` is replaced by `cosmetic_transformer(g)`
        before drawing. The transformed graph is the one recorded in
        `self.processed`.
        """
        if g in self.processed:
            return
        if self.beautify:
            # NOTE(review): the early-return check above tests the
            # untransformed graph, but the transformed one is recorded
            # below. Confirm cosmetic_transformer is cached, so repeated
            # calls with the same input graph stay idempotent.
            g = cosmetic_transformer(g)
        name = self.name(g)
        argnames = [self.name(p) for p in g.parameters]
        lbl = f'{name}({", ".join(argnames)})'
        classes = ['function', 'focus' if g in self.focus else '']
        self.cynode(id=g, label=lbl, classes=' '.join(classes))
        self.processed.add(g)

    def process_node_generic(self, node, g, cl):
        """Create node and edges for a node.

        Draws `node` inside graph `g` with classes `cl` and queues a
        called constant graph for processing. It links `g` to every
        graph used as a constant input with a 'use-edge'. It then
        creates an edge to the function, labelled 'F' and omitted when
        the function is a constant and `function_in_node` is set, and
        one edge to each argument, labelled with its 1-based position.
        """
        lbl = self.label(node)

        self.cynode(id=node, label=lbl, parent=g, classes=cl)

        fn = node.inputs[0] if node.inputs else None
        if fn and fn.is_constant_graph():
            # Called graphs are always queued, regardless of
            # follow_references.
            self.graphs.add(fn.value)

        for inp in node.inputs:
            if inp.is_constant_graph():
                self.cyedge(src_id=g,
                            dest_id=inp.value,
                            label=('', 'use-edge'))

        edges = []
        if fn and not (fn.is_constant() and self.function_in_node):
            edges.append((node, 'F', fn))

        edges += [(node, i, inp)
                  for i, inp in enumerate(node.inputs[1:], start=1)]

        self.process_edges(edges)

    def class_gen(self, node, cl=None):
        """Generate the class name for this node.

        If `cl` is given, it is used as the base class. Otherwise the
        base class comes from the node itself:

        - 'output' for return nodes;
        - 'input' for parameters, plus ' unlisted' when the parameter
          is not in its graph's parameter list;
        - 'constant' for constants;
        - 'special-<Type>' for special nodes;
        - 'intermediate' for everything else.

        ' error' is appended when `_has_error(node.debug)` is true. If
        a user hook is set, it receives the node with its cosmetic
        wrapping stripped, and its return value is used as the result.
        """
        g = node.graph
        if cl is not None:
            pass
        elif node in self.returns:
            cl = 'output'
        elif node.is_parameter():
            cl = 'input'
            if node not in g.parameters:
                cl += ' unlisted'
        elif node.is_constant():
            cl = 'constant'
        elif node.is_special():
            cl = f'special-{type(node.special).__name__}'
        else:
            cl = 'intermediate'
        if _has_error(node.debug):
            cl += ' error'
        if self._class_gen:
            return self._class_gen(self._strip_cosmetic(node), cl)
        else:
            return cl

    def process_node(self, node):
        """Create node and edges for a node.

        Nodes already processed are skipped. The node's graph is drawn
        first if needed. The node itself is drawn by the first of these
        that applies:

        - a registered cosmetic for its constant function;
        - the function's own `graph_display` method;
        - `process_node_generic`.
        """
        if node in self.processed:
            return

        g = node.graph
        self.follow(node)
        cl = self.class_gen(node)
        if g and g not in self.processed:
            self.add_graph(g)

        if node.inputs and node.inputs[0].is_constant():
            fn = node.inputs[0].value
            if fn in cosmetics:
                cosmetics[fn](self, node, g, cl)
            elif hasattr(fn, 'graph_display'):
                fn.graph_display(self, node, g, cl)
            else:
                self.process_node_generic(node, g, cl)
        else:
            self.process_node_generic(node, g, cl)

        self.processed.add(node)

    def process_edges(self, edges):
        """Create edges.

        Each edge is a ``(src, label, dest)`` triple. Depending on the
        printer options, `dest` is handled in one of three ways:

        - drawn as a fresh duplicate constant node;
        - drawn as a fresh free-variable node inside `src`'s graph,
          linked to the real node;
        - queued for processing and linked to directly.
        """
        for edge in edges:
            src, lbl, dest = edge
            if dest.is_constant() and self.duplicate_constants:
                self.follow(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.label(dest),
                            classes=self.class_gen(dest, 'constant'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
            elif self.duplicate_free_variables and \
                    src.graph and dest.graph and \
                    src.graph is not dest.graph:
                self.pool.add(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.labeler.label(dest, force=True),
                            classes=self.class_gen(dest, 'freevar'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
                self.cyedge(src_id=cid, dest_id=dest, label=(lbl, 'link-edge'))
                self.cyedge(src_id=src.graph,
                            dest_id=dest.graph,
                            label=('', 'nest-edge'))
            else:
                self.pool.add(dest)
                self.cyedge(src_id=src, dest_id=dest, label=lbl)

    def process_graph(self, g):
        """Process a graph.

        Draws the graph and its parameters, then walks backwards from
        what it returns. Every node reached through `process_edges` is
        queued in `self.pool`, and queued nodes are processed until the
        pool is empty.
        """
        self.add_graph(g)
        for inp in g.parameters:
            self.process_node(inp)

        if not g.return_:
            return

        ret = g.return_.inputs[1]
        # Color the returned node itself when it is an application that
        # belongs to g; otherwise color the return node.
        if not ret.is_apply() or ret.graph is not g:
            ret = g.return_

        self.returns.add(ret)
        self.pool.add(ret)

        while self.pool:
            node = self.pool.pop()
            self.process_node(node)

    def process(self):
        """Process all graphs in entry_points.

        Graphs are popped from `self.graphs` until it is empty. Graphs
        discovered while processing are added to that same set and are
        processed in turn.

        Returns:
            The ``(nodes, edges)`` pair accumulated by the printer. If
            nodes or edges were already produced, nothing is reprocessed
            and the existing pair is returned.
        """
        if not (self.nodes or self.edges):
            while self.graphs:
                self.process_graph(self.graphs.pop())
        return self.nodes, self.edges

    def follow(self, node):
        """Add this node's graph if follow_references is True."""
        if node.is_constant_graph() and self.follow_references:
            self.graphs.add(node.value)
示例#2
0
class MyiaNodesPrinter(GraphPrinter):
    """
    Utility to generate a graphical representation for a set of nodes.

    Only the given nodes are drawn, each inside the graph that owns it.
    Edges whose destination is not one of the given nodes are skipped.

    Attributes:
        duplicate_constants: Whether to draw a fresh node for every use
            of a constant, instead of a node for the constant itself.
        duplicate_free_variables: Whether to draw a fresh node for a
            use of a node that belongs to a different graph, linked to
            the real node.
        function_in_node: Whether to print, when possible, the name of
            a node's operation directly in the node's label instead of
            drawing an edge to the operation.

    """
    def __init__(self,
                 nodes,
                 *,
                 duplicate_constants=True,
                 duplicate_free_variables=True,
                 function_in_node=True,
                 tooltip_gen=None,
                 class_gen=None,
                 extra_style=None):
        """Initialize a MyiaNodesPrinter.

        Arguments:
            nodes: Iterable of nodes to draw. It is consumed only once,
                so a generator is accepted.
            duplicate_constants: See the class docstring.
            duplicate_free_variables: See the class docstring.
            function_in_node: See the class docstring; also forwarded
                to the `NodeLabeler`.
            tooltip_gen: Forwarded to `GraphPrinter`.
            class_gen: Optional hook used to customize node classes,
                normalized through `_make_class_gen`.
            extra_style: Forwarded to `GraphPrinter`.
        """
        super().__init__({'layout': {
            'name': 'dagre',
            'rankDir': 'TB'
        }},
                         tooltip_gen=tooltip_gen,
                         extra_style=extra_style)
        self.duplicate_constants = duplicate_constants
        self.duplicate_free_variables = duplicate_free_variables
        self.function_in_node = function_in_node
        self.labeler = NodeLabeler(function_in_node=function_in_node,
                                   relation_symbols=short_relation_symbols)
        self._class_gen = _make_class_gen(class_gen)
        # Materialize the nodes once. Everything below is derived from
        # this set, so a one-shot iterable is not read a second time
        # after it has been exhausted.
        self.todo = set(nodes)
        # Graphs that own at least one of the nodes to draw
        self.graphs = {node.graph for node in self.todo if node.graph}
        # Graphs drawn with the 'focus' class (none by default)
        self.focus = set()
        # Nodes that are to be colored as return nodes
        self.returns = {
            node
            for node in self.todo
            if node.graph and node is node.graph.return_
        }
        # IDs for duplicated constants
        self.currid = 0

    def name(self, x):
        """Return the name of a node."""
        return self.labeler.name(x, force=True)

    def label(self, node, fn_label=None):
        """Return the label to give to a node."""
        return self.labeler.label(node, None, fn_label=fn_label)

    def const_fn(self, node):
        """
        Return name of function, if constant.

        Given an `Apply` node of a constant function, return the
        name of that function, otherwise return None.
        """
        return self.labeler.const_fn(node)

    def add_graph(self, g):
        """Create a node for a graph, labelled with its name and parameters."""
        name = self.name(g)
        argnames = [self.name(p) for p in g.parameters]
        lbl = f'{name}({", ".join(argnames)})'
        classes = ['function', 'focus' if g in self.focus else '']
        self.cynode(id=g, label=lbl, classes=' '.join(classes))

    def process_node_generic(self, node, g, cl):
        """Create node and edges for a node.

        When `duplicate_constants` is set, a constant node is not drawn
        here; `process_edges` draws a fresh copy for each use instead.
        Otherwise, `node` is drawn inside `g` with classes `cl`. Edges
        are then created to the function (labelled 'F', unless it is a
        constant and `function_in_node` is set) and to each argument
        (labelled with its 1-based position).
        """
        if node.is_constant() and self.duplicate_constants:
            return

        lbl = self.label(node)

        self.cynode(id=node, label=lbl, parent=g, classes=cl)

        fn = node.inputs[0] if node.inputs else None
        if fn and fn.is_constant_graph():
            # NOTE(review): process() draws self.graphs before any node,
            # so a graph first added here is recorded but never passed
            # to add_graph. Confirm whether that is intended.
            self.graphs.add(fn.value)

        for inp in node.inputs:
            if inp.is_constant_graph():
                self.cyedge(src_id=g,
                            dest_id=inp.value,
                            label=('', 'use-edge'))

        edges = []
        if fn and not (fn.is_constant() and self.function_in_node):
            edges.append((node, 'F', fn))

        edges += [(node, i, inp)
                  for i, inp in enumerate(node.inputs[1:], start=1)]

        self.process_edges(edges)

    def class_gen(self, node, cl=None):
        """Generate the class name for this node.

        If `cl` is given, it is used as the base class. Otherwise the
        base class comes from the node itself:

        - 'output' for return nodes;
        - 'input' for parameters, plus ' unlisted' when the parameter
          is not in its graph's parameter list;
        - 'constant' for constants;
        - 'special-<Type>' for special nodes;
        - 'intermediate' for everything else.

        ' error' is appended when `_has_error(node.debug)` is true. If
        a user hook is set, it receives the node with its cosmetic
        wrapping stripped, and its return value is used as the result.
        """
        g = node.graph
        if cl is not None:
            pass
        elif node in self.returns:
            cl = 'output'
        elif node.is_parameter():
            cl = 'input'
            if node not in g.parameters:
                cl += ' unlisted'
        elif node.is_constant():
            cl = 'constant'
        elif node.is_special():
            cl = f'special-{type(node.special).__name__}'
        else:
            cl = 'intermediate'
        if _has_error(node.debug):
            cl += ' error'
        if self._class_gen:
            return self._class_gen(self._strip_cosmetic(node), cl)
        else:
            return cl

    def process_node(self, node):
        """Create node and edges for a node.

        The node is drawn by the first of these that applies:

        - a registered cosmetic for its constant function;
        - the function's own `graph_display` method;
        - `process_node_generic`.

        There is no deduplication here; `process` calls this exactly
        once per requested node.
        """
        g = node.graph
        cl = self.class_gen(node)

        if node.inputs and node.inputs[0].is_constant():
            fn = node.inputs[0].value
            if fn in cosmetics:
                cosmetics[fn](self, node, g, cl)
            elif hasattr(fn, 'graph_display'):
                fn.graph_display(self, node, g, cl)
            else:
                self.process_node_generic(node, g, cl)
        else:
            self.process_node_generic(node, g, cl)

    def process_edges(self, edges):
        """Create edges.

        Each edge is a ``(src, label, dest)`` triple. Edges whose `dest`
        is not one of the requested nodes are skipped. Otherwise, `dest`
        is handled in one of three ways:

        - drawn as a fresh duplicate constant node;
        - drawn as a fresh free-variable node inside `src`'s graph,
          linked to the real node;
        - linked to directly.
        """
        for edge in edges:
            src, lbl, dest = edge
            if dest not in self.todo:
                continue
            if dest.is_constant() and self.duplicate_constants:
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.label(dest),
                            classes=self.class_gen(dest, 'constant'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
            elif self.duplicate_free_variables and \
                    src.graph and dest.graph and \
                    src.graph is not dest.graph:
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.labeler.label(dest, force=True),
                            classes=self.class_gen(dest, 'freevar'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
                self.cyedge(src_id=cid, dest_id=dest, label=(lbl, 'link-edge'))
                self.cyedge(src_id=src.graph,
                            dest_id=dest.graph,
                            label=('', 'nest-edge'))
            else:
                self.cyedge(src_id=src, dest_id=dest, label=lbl)

    def process(self):
        """Draw the graphs owning the requested nodes, then the nodes.

        Returns:
            The ``(nodes, edges)`` pair accumulated by the printer. If
            nodes or edges were already produced, nothing is redrawn and
            the existing pair is returned.
        """
        if not (self.nodes or self.edges):
            for g in self.graphs:
                self.add_graph(g)
            for node in self.todo:
                self.process_node(node)
        return self.nodes, self.edges
示例#3
0
文件: gprint.py 项目: jangocheng/myia
class MyiaGraphPrinter(GraphPrinter):
    """
    Utility to generate a graphical representation for a graph.

    Attributes:
        duplicate_constants: Whether to create a separate node for
            every instance of the same constant.
        duplicate_free_variables: Whether to create a separate node
            to represent the use of a free variable, or point directly
            to that node in a different graph.
        function_in_node: Whether to print, when possible, the name
            of a node's operation directly in the node's label instead
            of creating a node for the operation and drawing an edge
            to it.
        follow_references: Whether to also print graphs that are
            called by this graph.

    """
    def __init__(self,
                 entry_points,
                 *,
                 duplicate_constants=False,
                 duplicate_free_variables=False,
                 function_in_node=False,
                 follow_references=False,
                 tooltip_gen=None,
                 class_gen=None,
                 extra_style=None):
        """Initialize a MyiaGraphPrinter.

        Arguments:
            entry_points: Iterable of graphs to print.
            duplicate_constants: See the class docstring.
            duplicate_free_variables: See the class docstring.
            function_in_node: See the class docstring; also forwarded
                to the `NodeLabeler`.
            follow_references: See the class docstring.
            tooltip_gen: Forwarded to `GraphPrinter`.
            class_gen: Optional ``class_gen(node, cl)`` hook whose return
                value replaces the computed class string.
            extra_style: Forwarded to `GraphPrinter`.
        """
        super().__init__({'layout': {
            'name': 'dagre',
            'rankDir': 'TB'
        }},
                         tooltip_gen=tooltip_gen,
                         extra_style=extra_style)
        # Graphs left to process
        self.graphs = set(entry_points)
        self.duplicate_constants = duplicate_constants
        self.duplicate_free_variables = duplicate_free_variables
        self.function_in_node = function_in_node
        self.follow_references = follow_references
        self.labeler = NodeLabeler(function_in_node=function_in_node,
                                   relation_symbols=short_relation_symbols)
        self._class_gen = class_gen
        # Nodes (and graphs) processed
        self.processed = set()
        # Nodes left to process
        self.pool = set()
        # Nodes that are to be colored as return nodes
        self.returns = set()
        # IDs for duplicated constants
        self.currid = 0
        # Custom rules for nodes that represent certain calls, keyed by
        # the constant function name returned by `const_fn`.
        self.custom_rules = {
            'return': self.process_node_return,
            'getitem': self.process_node_getitem,
            'make_tuple': self.process_node_make_tuple
        }

    def name(self, x):
        """Return the name of a node."""
        return self.labeler.name(x, force=True)

    def label(self, node, fn_label=None):
        """Return the label to give to a node."""
        return self.labeler.label(node, None, fn_label=fn_label)

    def const_fn(self, node):
        """
        Return name of function, if constant.

        Given an `Apply` node of a constant function, return the
        name of that function, otherwise return None.
        """
        return self.labeler.const_fn(node)

    def add_graph(self, g):
        """Create a node for a graph, at most once per graph."""
        if g in self.processed:
            return
        name = self.name(g)
        argnames = [self.name(p) for p in g.parameters]
        lbl = f'{name}({", ".join(argnames)})'
        self.cynode(id=g, label=lbl, classes='function')
        self.processed.add(g)

    def process_node_return(self, node, g, cl):
        """Create node and edges for `return ...`.

        Draws an unlabelled 'const_output' node linked to the returned
        value by an unlabelled edge.
        """
        self.cynode(id=node, label='', parent=g, classes='const_output')
        ret = node.inputs[1]
        self.process_edges([(node, '', ret)])

    def process_node_getitem(self, node, g, cl):
        """Create node and edges for `x[ct]`.

        When `function_in_node` is set and the index is a constant, the
        index is shown as a ``[idx]`` label on the edge to the indexed
        value. Otherwise, the node is drawn generically.
        """
        idx = node.inputs[2]
        if self.function_in_node and is_constant(idx):
            lbl = self.label(node, '')
            self.cynode(id=node, label=lbl, parent=g, classes=cl)
            self.process_edges([(node, (f'[{idx.value}]', 'fn-edge'),
                                 node.inputs[1])])
        else:
            self.process_node_generic(node, g, cl)

    def process_node_make_tuple(self, node, g, cl):
        """Create node and edges for `(a, b, c, ...)`.

        When `function_in_node` is set, the node is labelled ``(...)``
        and linked to each element by an edge labelled with its 1-based
        position. Otherwise, the node is drawn generically.
        """
        if self.function_in_node:
            lbl = self.label(node, '(...)')
            self.cynode(id=node, label=lbl, parent=g, classes=cl)
            edges = [(node, i, inp)
                     for i, inp in enumerate(node.inputs[1:], start=1)]
            self.process_edges(edges)
        else:
            self.process_node_generic(node, g, cl)

    def process_node_generic(self, node, g, cl):
        """Create node and edges for a node.

        Draws `node` inside graph `g` with classes `cl` and queues a
        called constant graph for processing. It then creates an edge to
        the function, labelled 'F' and omitted when the function is a
        constant and `function_in_node` is set, and one edge to each
        argument, labelled with its 1-based position.
        """
        lbl = self.label(node)

        self.cynode(id=node, label=lbl, parent=g, classes=cl)

        fn = node.inputs[0] if node.inputs else None
        if fn and is_constant_graph(fn):
            # Called graphs are always queued, regardless of
            # follow_references.
            self.graphs.add(fn.value)

        edges = []
        if fn and not (is_constant(fn) and self.function_in_node):
            edges.append((node, 'F', fn))

        edges += [(node, i, inp)
                  for i, inp in enumerate(node.inputs[1:], start=1)]

        self.process_edges(edges)

    def class_gen(self, node, cl=None):
        """Generate the class name for this node.

        If `cl` is given, it is used as the base class. Otherwise the
        base class is 'output' for return nodes, 'input' for parameters
        of the node's graph, 'constant' for constants, and
        'intermediate' for everything else. If a user hook is set, its
        return value is used as the result.
        """
        g = node.graph
        if cl is not None:
            pass
        elif node in self.returns:
            cl = 'output'
        elif g and node in g.parameters:
            cl = 'input'
        elif is_constant(node):
            cl = 'constant'
        else:
            cl = 'intermediate'
        if self._class_gen:
            return self._class_gen(node, cl)
        else:
            return cl

    def process_node(self, node):
        """Create node and edges for a node.

        Nodes already processed are skipped. Calls to a constant
        function that has an entry in `custom_rules` are drawn by that
        rule. Everything else goes through `process_node_generic`.
        """
        if node in self.processed:
            return

        g = node.graph
        self.follow(node)
        cl = self.class_gen(node)

        ctfn = self.const_fn(node)
        if ctfn:
            if ctfn in self.custom_rules:
                self.custom_rules[ctfn](node, g, cl)
            else:
                self.process_node_generic(node, g, cl)
        else:
            self.process_node_generic(node, g, cl)

        self.processed.add(node)

    def process_edges(self, edges):
        """Create edges.

        Each edge is a ``(src, label, dest)`` triple. Depending on the
        printer options, `dest` is handled in one of three ways:

        - drawn as a fresh duplicate constant node;
        - drawn as a fresh free-variable node inside `src`'s graph,
          linked to the real node;
        - queued for processing and linked to directly.
        """
        for edge in edges:
            src, lbl, dest = edge
            if is_constant(dest) and self.duplicate_constants:
                self.follow(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.label(dest),
                            classes=self.class_gen(dest, 'constant'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
            elif self.duplicate_free_variables and \
                    src.graph and dest.graph and \
                    src.graph is not dest.graph:
                self.pool.add(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.name(dest),
                            classes=self.class_gen(dest, 'freevar'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
                self.cyedge(src_id=cid, dest_id=dest, label=(lbl, 'link-edge'))
            else:
                self.pool.add(dest)
                self.cyedge(src_id=src, dest_id=dest, label=lbl)

    def process_graph(self, g):
        """Process a graph.

        Draws the graph and its parameters, then walks backwards from
        what it returns. Every node reached through `process_edges` is
        queued in `self.pool`, and queued nodes are processed until the
        pool is empty.
        """
        self.add_graph(g)
        for inp in g.parameters:
            self.process_node(inp)

        if not g.return_:
            return

        ret = g.return_.inputs[1]
        # Color the returned node itself when it is an application that
        # belongs to g; otherwise color the return node.
        if not is_apply(ret) or ret.graph is not g:
            ret = g.return_

        self.returns.add(ret)
        self.pool.add(ret)

        while self.pool:
            node = self.pool.pop()
            self.process_node(node)

    def process(self):
        """Process all graphs in entry_points.

        Graphs are popped from `self.graphs` until it is empty. Graphs
        discovered while processing are added to that same set and are
        processed in turn.

        Returns:
            The ``(nodes, edges)`` pair accumulated by the printer. If
            nodes or edges were already produced, nothing is reprocessed
            and the existing pair is returned.
        """
        if not (self.nodes or self.edges):
            while self.graphs:
                self.process_graph(self.graphs.pop())
        return self.nodes, self.edges

    def follow(self, node):
        """Add this node's graph if follow_references is True."""
        if is_constant_graph(node) and self.follow_references:
            self.graphs.add(node.value)