class MyiaGraphPrinter(GraphPrinter):
    """
    Utility to generate a graphical representation for a graph.

    Attributes:
        duplicate_constants: Whether to create a separate node for
            every instance of the same constant.
        duplicate_free_variables: Whether to create a separate node
            to represent the use of a free variable, or point directly
            to that node in a different graph.
        function_in_node: Whether to print, when possible, the name
            of a node's operation directly in the node's label instead
            of creating a node for the operation and drawing an edge
            to it.
        follow_references: Whether to also print graphs that are
            referenced as constants by the printed graphs. Graphs that
            appear in the function position of a call are queued for
            printing regardless of this flag (see `process_node_generic`).
        beautify: Whether each entry point is cloned together with its
            chain of ancestor graphs before printing, and whether graphs
            are passed through `cosmetic_transformer` when they are added.

    """

    def __init__(self,
                 entry_points,
                 *,
                 duplicate_constants=False,
                 duplicate_free_variables=False,
                 function_in_node=False,
                 follow_references=False,
                 tooltip_gen=None,
                 class_gen=None,
                 extra_style=None,
                 beautify=True):
        """Initialize a MyiaGraphPrinter."""
        super().__init__(
            {
                'layout': {
                    'name': 'dagre',
                    'rankDir': 'TB'
                }
            },
            tooltip_gen=tooltip_gen,
            extra_style=extra_style
        )
        # Graphs left to process (self.graphs) and graphs corresponding to
        # the entry points, which get the 'focus' class (self.focus).
        if beautify:
            self.graphs = set()
            self.focus = set()
            for g in entry_points:
                self._import_graph(g)
        else:
            self.graphs = set(entry_points)
            self.focus = set(self.graphs)
        self.beautify = beautify
        self.duplicate_constants = duplicate_constants
        self.duplicate_free_variables = duplicate_free_variables
        self.function_in_node = function_in_node
        self.follow_references = follow_references
        self.labeler = NodeLabeler(
            function_in_node=function_in_node,
            relation_symbols=short_relation_symbols
        )
        self._class_gen = _make_class_gen(class_gen)
        # Nodes processed
        self.processed = set()
        # Nodes left to process
        self.pool = set()
        # Nodes that are to be colored as return nodes
        self.returns = set()
        # IDs for duplicated constants
        self.currid = 0

    def _import_graph(self, graph):
        """Clone `graph` and its ancestors, and queue the clones.

        The graph and every ancestor reachable through the manager's
        `parents` mapping are cloned together (relation 'cosmetic').
        All clones are added to `self.graphs`, and the clone of `graph`
        itself is added to `self.focus`.
        """
        mng = manage(graph, weak=True)
        graphs = set()
        parents = mng.parents
        g = graph
        # Walk up the parent chain until there is no parent left.
        while g:
            graphs.add(g)
            g = parents[g]
        clone = GraphCloner(*graphs, total=True, relation='cosmetic')
        self.graphs |= {clone[g] for g in graphs}
        self.focus.add(clone[graph])

    def name(self, x):
        """Return the name of a node."""
        return self.labeler.name(x, force=True)

    def label(self, node, fn_label=None):
        """Return the label to give to a node."""
        return self.labeler.label(node, None, fn_label=fn_label)

    def const_fn(self, node):
        """
        Return name of function, if constant.

        Given an `Apply` node of a constant function, return the
        name of that function, otherwise return None.
        """
        return self.labeler.const_fn(node)

    def add_graph(self, g):
        """Create a node for a graph.

        The node is labelled ``name(param1, param2, ...)`` and gets the
        'focus' class when `g` is one of the focus graphs. Graphs already
        in `self.processed` are skipped.
        """
        if g in self.processed:
            return
        if self.beautify:
            g = cosmetic_transformer(g)
        name = self.name(g)
        argnames = [self.name(p) for p in g.parameters]
        lbl = f'{name}({", ".join(argnames)})'
        classes = ['function', 'focus' if g in self.focus else '']
        self.cynode(id=g, label=lbl, classes=' '.join(classes))
        self.processed.add(g)

    def process_node_generic(self, node, g, cl):
        """Create node and edges for a node.

        Any graph in function position is queued in `self.graphs`, and a
        'use-edge' is drawn from `g` to every constant-graph input.
        """
        lbl = self.label(node)
        self.cynode(id=node, label=lbl, parent=g, classes=cl)

        fn = node.inputs[0] if node.inputs else None
        if fn and fn.is_constant_graph():
            self.graphs.add(fn.value)

        for inp in node.inputs:
            if inp.is_constant_graph():
                self.cyedge(src_id=g, dest_id=inp.value,
                            label=('', 'use-edge'))

        edges = []
        # With function_in_node, a constant function is shown in the
        # node's label instead of through an 'F' edge.
        if fn and not (fn.is_constant() and self.function_in_node):
            edges.append((node, 'F', fn))

        # Arguments are labelled 1, 2, ... after the function input.
        edges += [(node, i, inp)
                  for i, inp in enumerate(node.inputs[1:], start=1)]

        self.process_edges(edges)

    def class_gen(self, node, cl=None):
        """Generate the class name for this node.

        If `cl` is given it is used as the base class. Otherwise the class
        is derived from the node's kind. ' error' is appended when the
        node's debug info has an error. If a custom class generator was
        given, it receives the node (via `_strip_cosmetic`) and the
        computed class, and its result is returned.
        """
        g = node.graph
        if cl is not None:
            pass
        elif node in self.returns:
            cl = 'output'
        elif node.is_parameter():
            cl = 'input'
            if node not in g.parameters:
                cl += ' unlisted'
        elif node.is_constant():
            cl = 'constant'
        elif node.is_special():
            cl = f'special-{type(node.special).__name__}'
        else:
            cl = 'intermediate'
        if _has_error(node.debug):
            cl += ' error'
        if self._class_gen:
            return self._class_gen(self._strip_cosmetic(node), cl)
        else:
            return cl

    def process_node(self, node):
        """Create node and edges for a node.

        Nodes whose function is a constant present in `cosmetics`, or
        whose constant function value has a `graph_display` attribute,
        are rendered by that handler. All others go through
        `process_node_generic`.
        """
        if node in self.processed:
            return

        g = node.graph
        self.follow(node)
        cl = self.class_gen(node)
        if g and g not in self.processed:
            self.add_graph(g)

        if node.inputs and node.inputs[0].is_constant():
            fn = node.inputs[0].value
            if fn in cosmetics:
                cosmetics[fn](self, node, g, cl)
            elif hasattr(fn, 'graph_display'):
                fn.graph_display(self, node, g, cl)
            else:
                self.process_node_generic(node, g, cl)
        else:
            self.process_node_generic(node, g, cl)

        self.processed.add(node)

    def process_edges(self, edges):
        """Create edges.

        Each edge is a ``(src, label, dest)`` triple. Depending on the
        duplication settings, `dest` is drawn as a fresh per-use node or
        linked to directly (and queued in `self.pool` for processing).
        """
        for edge in edges:
            src, lbl, dest = edge
            if dest.is_constant() and self.duplicate_constants:
                # Draw a fresh copy of the constant for this use.
                self.follow(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.label(dest),
                            classes=self.class_gen(dest, 'constant'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
            elif self.duplicate_free_variables and \
                    src.graph and dest.graph and \
                    src.graph is not dest.graph:
                # Free variable: draw a local proxy node inside src's graph
                # and link it to the real node in the other graph.
                self.pool.add(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.labeler.label(dest, force=True),
                            classes=self.class_gen(dest, 'freevar'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
                self.cyedge(src_id=cid, dest_id=dest,
                            label=(lbl, 'link-edge'))
                self.cyedge(src_id=src.graph, dest_id=dest.graph,
                            label=('', 'nest-edge'))
            else:
                self.pool.add(dest)
                self.cyedge(src_id=src, dest_id=dest, label=lbl)

    def process_graph(self, g):
        """Process a graph.

        Adds the graph node and its parameters, then drains `self.pool`
        starting from the return value (or the return node itself when
        the returned value is not an Apply node belonging to `g`).
        """
        self.add_graph(g)
        for inp in g.parameters:
            self.process_node(inp)

        if not g.return_:
            return

        ret = g.return_.inputs[1]
        if not ret.is_apply() or ret.graph is not g:
            ret = g.return_

        self.returns.add(ret)
        self.pool.add(ret)

        while self.pool:
            node = self.pool.pop()
            self.process_node(node)

    def process(self):
        """Process all graphs in entry_points.

        Returns:
            The ``(nodes, edges)`` collections built by the printer. If
            they are already populated, no reprocessing happens and the
            existing collections are returned (previously this returned
            None in that case).
        """
        if not (self.nodes or self.edges):
            while self.graphs:
                g = self.graphs.pop()
                self.process_graph(g)
        return self.nodes, self.edges

    def follow(self, node):
        """Add this node's graph if follow_references is True."""
        if node.is_constant_graph() and self.follow_references:
            self.graphs.add(node.value)
class MyiaNodesPrinter(GraphPrinter):
    """
    Utility to generate a graphical representation of a set of nodes.

    Only the given nodes and the graphs that own them are drawn. Edges
    whose destination is not among the given nodes are skipped.

    Attributes:
        duplicate_constants: Whether to create a separate node for
            every use of a constant instead of a single shared node.
        duplicate_free_variables: Whether to create a separate node
            to represent the use of a free variable, or point directly
            to that node in a different graph.
        function_in_node: Whether to print, when possible, the name
            of a node's operation directly in the node's label instead
            of drawing an edge to the operation.

    """

    def __init__(self,
                 nodes,
                 *,
                 duplicate_constants=True,
                 duplicate_free_variables=True,
                 function_in_node=True,
                 tooltip_gen=None,
                 class_gen=None,
                 extra_style=None):
        """Initialize a MyiaNodesPrinter.

        Args:
            nodes: Iterable of nodes to display. It is iterated exactly
                once, so generators and other one-shot iterators work.
        """
        super().__init__(
            {
                'layout': {
                    'name': 'dagre',
                    'rankDir': 'TB'
                }
            },
            tooltip_gen=tooltip_gen,
            extra_style=extra_style
        )
        self.duplicate_constants = duplicate_constants
        self.duplicate_free_variables = duplicate_free_variables
        self.function_in_node = function_in_node
        self.labeler = NodeLabeler(
            function_in_node=function_in_node,
            relation_symbols=short_relation_symbols
        )
        self._class_gen = _make_class_gen(class_gen)
        # Nodes to display. Materialize `nodes` once; the comprehensions
        # below iterate this set so a one-shot iterator is not exhausted.
        self.todo = set(nodes)
        # Graphs that own at least one of the nodes
        self.graphs = {node.graph for node in self.todo if node.graph}
        self.focus = set()
        # Nodes that are to be colored as return nodes
        self.returns = {
            node for node in self.todo
            if node.graph and node is node.graph.return_
        }
        # IDs for duplicated constants
        self.currid = 0

    def name(self, x):
        """Return the name of a node."""
        return self.labeler.name(x, force=True)

    def label(self, node, fn_label=None):
        """Return the label to give to a node."""
        return self.labeler.label(node, None, fn_label=fn_label)

    def const_fn(self, node):
        """
        Return name of function, if constant.

        Given an `Apply` node of a constant function, return the
        name of that function, otherwise return None.
        """
        return self.labeler.const_fn(node)

    def add_graph(self, g):
        """Create a node for a graph, labelled ``name(param1, ...)``."""
        name = self.name(g)
        argnames = [self.name(p) for p in g.parameters]
        lbl = f'{name}({", ".join(argnames)})'
        classes = ['function', 'focus' if g in self.focus else '']
        self.cynode(id=g, label=lbl, classes=' '.join(classes))

    def process_node_generic(self, node, g, cl):
        """Create node and edges for a node.

        With duplicate_constants, constant nodes are not drawn on their
        own; a fresh copy is drawn per use in `process_edges`.
        """
        if node.is_constant() and self.duplicate_constants:
            return

        lbl = self.label(node)
        self.cynode(id=node, label=lbl, parent=g, classes=cl)

        fn = node.inputs[0] if node.inputs else None
        if fn and fn.is_constant_graph():
            # NOTE(review): `process` emits graph nodes before processing
            # nodes, so graphs added here never get a node of their own.
            self.graphs.add(fn.value)

        for inp in node.inputs:
            if inp.is_constant_graph():
                self.cyedge(src_id=g, dest_id=inp.value,
                            label=('', 'use-edge'))

        edges = []
        # With function_in_node, a constant function is shown in the
        # node's label instead of through an 'F' edge.
        if fn and not (fn.is_constant() and self.function_in_node):
            edges.append((node, 'F', fn))

        # Arguments are labelled 1, 2, ... after the function input.
        edges += [(node, i, inp)
                  for i, inp in enumerate(node.inputs[1:], start=1)]

        self.process_edges(edges)

    def class_gen(self, node, cl=None):
        """Generate the class name for this node.

        If `cl` is given it is used as the base class. Otherwise the class
        is derived from the node's kind. ' error' is appended when the
        node's debug info has an error. If a custom class generator was
        given, its result is returned instead.
        """
        g = node.graph
        if cl is not None:
            pass
        elif node in self.returns:
            cl = 'output'
        elif node.is_parameter():
            cl = 'input'
            if node not in g.parameters:
                cl += ' unlisted'
        elif node.is_constant():
            cl = 'constant'
        elif node.is_special():
            cl = f'special-{type(node.special).__name__}'
        else:
            cl = 'intermediate'
        if _has_error(node.debug):
            cl += ' error'
        if self._class_gen:
            return self._class_gen(self._strip_cosmetic(node), cl)
        else:
            return cl

    def process_node(self, node):
        """Create node and edges for a node.

        Nodes whose function is a constant present in `cosmetics`, or
        whose constant function value has a `graph_display` attribute,
        are rendered by that handler. All others go through
        `process_node_generic`.
        """
        g = node.graph
        cl = self.class_gen(node)

        if node.inputs and node.inputs[0].is_constant():
            fn = node.inputs[0].value
            if fn in cosmetics:
                cosmetics[fn](self, node, g, cl)
            elif hasattr(fn, 'graph_display'):
                fn.graph_display(self, node, g, cl)
            else:
                self.process_node_generic(node, g, cl)
        else:
            self.process_node_generic(node, g, cl)

    def process_edges(self, edges):
        """Create edges.

        Each edge is a ``(src, label, dest)`` triple. Edges whose `dest`
        is not among the nodes to display are skipped.
        """
        for edge in edges:
            src, lbl, dest = edge
            if dest not in self.todo:
                continue
            if dest.is_constant() and self.duplicate_constants:
                # Draw a fresh copy of the constant for this use.
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.label(dest),
                            classes=self.class_gen(dest, 'constant'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
            elif self.duplicate_free_variables and \
                    src.graph and dest.graph and \
                    src.graph is not dest.graph:
                # Free variable: draw a local proxy node inside src's graph
                # and link it to the real node in the other graph.
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.labeler.label(dest, force=True),
                            classes=self.class_gen(dest, 'freevar'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
                self.cyedge(src_id=cid, dest_id=dest,
                            label=(lbl, 'link-edge'))
                self.cyedge(src_id=src.graph, dest_id=dest.graph,
                            label=('', 'nest-edge'))
            else:
                self.cyedge(src_id=src, dest_id=dest, label=lbl)

    def process(self):
        """Process all the nodes to display.

        Returns:
            The ``(nodes, edges)`` collections built by the printer. If
            they are already populated, no reprocessing happens and the
            existing collections are returned (previously this returned
            None in that case).
        """
        if not (self.nodes or self.edges):
            for g in self.graphs:
                self.add_graph(g)
            for node in self.todo:
                self.process_node(node)
        return self.nodes, self.edges
class MyiaGraphPrinter(GraphPrinter):
    """
    Utility to generate a graphical representation for a graph.

    Attributes:
        duplicate_constants: Whether to create a separate node for
            every instance of the same constant.
        duplicate_free_variables: Whether to create a separate node
            to represent the use of a free variable, or point directly
            to that node in a different graph.
        function_in_node: Whether to print, when possible, the name
            of a node's operation directly in the node's label instead
            of creating a node for the operation and drawing an edge
            to it.
        follow_references: Whether to also print graphs that are
            referenced as constants by the printed graphs. Graphs that
            appear in the function position of a call are queued for
            printing regardless of this flag (see `process_node_generic`).

    """

    def __init__(self,
                 entry_points,
                 *,
                 duplicate_constants=False,
                 duplicate_free_variables=False,
                 function_in_node=False,
                 follow_references=False,
                 tooltip_gen=None,
                 class_gen=None,
                 extra_style=None):
        """Initialize a MyiaGraphPrinter."""
        super().__init__(
            {
                'layout': {
                    'name': 'dagre',
                    'rankDir': 'TB'
                }
            },
            tooltip_gen=tooltip_gen,
            extra_style=extra_style
        )
        # Graphs left to process
        self.graphs = set(entry_points)
        self.duplicate_constants = duplicate_constants
        self.duplicate_free_variables = duplicate_free_variables
        self.function_in_node = function_in_node
        self.follow_references = follow_references
        self.labeler = NodeLabeler(
            function_in_node=function_in_node,
            relation_symbols=short_relation_symbols
        )
        self._class_gen = class_gen
        # Nodes processed
        self.processed = set()
        # Nodes left to process
        self.pool = set()
        # Nodes that are to be colored as return nodes
        self.returns = set()
        # IDs for duplicated constants
        self.currid = 0
        # Custom rules for nodes that represent certain calls, keyed by
        # the name returned by `const_fn`.
        self.custom_rules = {
            'return': self.process_node_return,
            'getitem': self.process_node_getitem,
            'make_tuple': self.process_node_make_tuple
        }

    def name(self, x):
        """Return the name of a node."""
        return self.labeler.name(x, force=True)

    def label(self, node, fn_label=None):
        """Return the label to give to a node."""
        return self.labeler.label(node, None, fn_label=fn_label)

    def const_fn(self, node):
        """
        Return name of function, if constant.

        Given an `Apply` node of a constant function, return the
        name of that function, otherwise return None.
        """
        return self.labeler.const_fn(node)

    def add_graph(self, g):
        """Create a node for a graph, labelled ``name(param1, ...)``."""
        if g in self.processed:
            return
        name = self.name(g)
        argnames = [self.name(p) for p in g.parameters]
        lbl = f'{name}({", ".join(argnames)})'
        self.cynode(id=g, label=lbl, classes='function')
        self.processed.add(g)

    def process_node_return(self, node, g, cl):
        """Create node and edges for `return ...`."""
        self.cynode(id=node, label='', parent=g, classes='const_output')
        ret = node.inputs[1]
        self.process_edges([(node, '', ret)])

    def process_node_getitem(self, node, g, cl):
        """Create node and edges for `x[ct]`.

        With function_in_node and a constant index, the index is shown on
        the edge label (``[idx]``) instead of drawing the getitem call.
        """
        idx = node.inputs[2]
        if self.function_in_node and is_constant(idx):
            lbl = self.label(node, '')
            self.cynode(id=node, label=lbl, parent=g, classes=cl)
            self.process_edges([(node,
                                 (f'[{idx.value}]', 'fn-edge'),
                                 node.inputs[1])])
        else:
            self.process_node_generic(node, g, cl)

    def process_node_make_tuple(self, node, g, cl):
        """Create node and edges for `(a, b, c, ...)`."""
        if self.function_in_node:
            lbl = self.label(node, '(...)')
            self.cynode(id=node, label=lbl, parent=g, classes=cl)
            edges = [(node, i, inp)
                     for i, inp in enumerate(node.inputs[1:], start=1)]
            self.process_edges(edges)
        else:
            self.process_node_generic(node, g, cl)

    def process_node_generic(self, node, g, cl):
        """Create node and edges for a node.

        Any graph in function position is queued in `self.graphs`.
        """
        lbl = self.label(node)
        self.cynode(id=node, label=lbl, parent=g, classes=cl)

        fn = node.inputs[0] if node.inputs else None
        if fn and is_constant_graph(fn):
            self.graphs.add(fn.value)

        edges = []
        # With function_in_node, a constant function is shown in the
        # node's label instead of through an 'F' edge.
        if fn and not (is_constant(fn) and self.function_in_node):
            edges.append((node, 'F', fn))

        # Arguments are labelled 1, 2, ... after the function input.
        edges += [(node, i, inp)
                  for i, inp in enumerate(node.inputs[1:], start=1)]

        self.process_edges(edges)

    def class_gen(self, node, cl=None):
        """Generate the class name for this node.

        If `cl` is given it is used as-is. Otherwise the class is derived
        from the node's role. If a custom class generator was given, it
        receives the node and the computed class, and its result is
        returned.
        """
        g = node.graph
        if cl is not None:
            pass
        elif node in self.returns:
            cl = 'output'
        elif g and node in g.parameters:
            cl = 'input'
        elif is_constant(node):
            cl = 'constant'
        else:
            cl = 'intermediate'
        if self._class_gen:
            return self._class_gen(node, cl)
        else:
            return cl

    def process_node(self, node):
        """Create node and edges for a node.

        Calls to a constant function named in `custom_rules` use that
        rule; every other node goes through `process_node_generic`.
        """
        if node in self.processed:
            return

        g = node.graph
        self.follow(node)
        cl = self.class_gen(node)
        ctfn = self.const_fn(node)
        rule = self.custom_rules.get(ctfn) if ctfn else None
        if rule is not None:
            rule(node, g, cl)
        else:
            self.process_node_generic(node, g, cl)

        self.processed.add(node)

    def process_edges(self, edges):
        """Create edges.

        Each edge is a ``(src, label, dest)`` triple. Depending on the
        duplication settings, `dest` is drawn as a fresh per-use node or
        linked to directly (and queued in `self.pool` for processing).
        """
        for edge in edges:
            src, lbl, dest = edge
            if is_constant(dest) and self.duplicate_constants:
                # Draw a fresh copy of the constant for this use.
                self.follow(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.label(dest),
                            classes=self.class_gen(dest, 'constant'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
            elif self.duplicate_free_variables and \
                    src.graph and dest.graph and \
                    src.graph is not dest.graph:
                # Free variable: draw a local proxy node inside src's graph
                # and link it to the real node in the other graph.
                self.pool.add(dest)
                cid = self.fresh_id()
                self.cynode(id=cid,
                            parent=src.graph,
                            label=self.name(dest),
                            classes=self.class_gen(dest, 'freevar'),
                            node=dest)
                self.cyedge(src_id=src, dest_id=cid, label=lbl)
                self.cyedge(src_id=cid, dest_id=dest,
                            label=(lbl, 'link-edge'))
            else:
                self.pool.add(dest)
                self.cyedge(src_id=src, dest_id=dest, label=lbl)

    def process_graph(self, g):
        """Process a graph.

        Adds the graph node and its parameters, then drains `self.pool`
        starting from the return value (or the return node itself when
        the returned value is not an Apply node belonging to `g`).
        """
        self.add_graph(g)
        for inp in g.parameters:
            self.process_node(inp)

        if not g.return_:
            return

        ret = g.return_.inputs[1]
        if not is_apply(ret) or ret.graph is not g:
            ret = g.return_

        self.returns.add(ret)
        self.pool.add(ret)

        while self.pool:
            node = self.pool.pop()
            self.process_node(node)

    def process(self):
        """Process all graphs in entry_points.

        Returns:
            The ``(nodes, edges)`` collections built by the printer. If
            they are already populated, no reprocessing happens and the
            existing collections are returned (previously this returned
            None in that case).
        """
        if not (self.nodes or self.edges):
            while self.graphs:
                g = self.graphs.pop()
                self.process_graph(g)
        return self.nodes, self.edges

    def follow(self, node):
        """Add this node's graph if follow_references is True."""
        if is_constant_graph(node) and self.follow_references:
            self.graphs.add(node.value)