def fix_split_in_edges(G: NNGraph):
    """Re-attach each split node's incoming edge to input index 0.

    For every ``SplitParameters`` node whose first incoming edge does not
    target ``to_idx == 0``, that edge is removed and replaced by a new
    ``NNEdge`` with the same source node, source index and destination node,
    but with no explicit ``to_idx`` (so ``NNEdge``'s default is used —
    presumably 0; confirm against ``NNEdge``).

    Args:
        G: graph to patch in place.
    """
    # Snapshot the split nodes first: the loop below mutates G's edges.
    split_nodes = list(filter(lambda n: isinstance(n, SplitParameters), G.nodes()))
    for split_node in split_nodes:
        first_in = G.in_edges(split_node.name)[0]
        if first_in.to_idx != 0:
            G.remove_edge(first_in)
            G.add_edge(NNEdge(first_in.from_node, first_in.to_node,
                              from_idx=first_in.from_idx))
def create_graph(self, filename, opts):
    """Load a TFLite flatbuffer file and build an ``NNGraph`` from it.

    Steps performed:
    1. The file is read, and the model version is checked (only version 3 is
       accepted).
    2. The graph is imported and cleaned: dangling nodes removed, split
       input edges fixed, duplicate constants merged, concats removed, and
       optionally quantize operators removed.
    3. When ``load_quantization`` is set, quantization records for nodes
       that no longer exist are dropped. The graph is then requantized,
       starting from nodes with bad quantization plus every concat, softmax,
       split and sigmoid node.

    Args:
        filename: path to the ``.tflite`` file.
        opts: import options; normalized through ``self.get_opts``. Keys
            read here: ``name``, ``load_quantization``,
            ``remove_quantize_ops``.

    Returns:
        The constructed ``NNGraph``.
    """
    opts = self.get_opts(opts)
    self._name_cache = {}
    add_sys_path(os.path.dirname(__file__))
    # Context manager so the file handle is released immediately rather than
    # whenever the file object happens to be garbage-collected.
    with open(filename, "rb") as model_file:
        buf = model_file.read()
    model = Model.GetRootAsModel(buf, 0)
    LOG.info("Importing TFLITE model version %s", model.Version())
    check(model.Version() == 3, "Only support version 3 graphs at present")
    if model.SubgraphsLength() > 1:
        LOG.warning(
            "nntool only supports one subgraph. There may be errors loading this graph."
        )
    G = NNGraph(model=model, filename=filename, name=opts.get('name'),
                constant_store=ConstantStore())
    if opts.get('load_quantization'):
        G.quantization = QuantizationSet()
        G.has_quantized_parameters = True
        G.quantization.schemes_present.add('SQ8')

    self._import_tflite_graph(G, model, opts)
    clean_dangling_nodes(G)
    fix_split_in_edges(G)
    MatchDuplicateConstants().match(G)
    G.add_dimensions()
    remove_concats(G)
    if opts['remove_quantize_ops']:
        RemoveQuantizeOperators().match(G)
        G.add_dimensions()

    if opts.get('load_quantization'):
        # Get rid of qrecs on nodes that were removed during cleanup. Collect
        # first; deleting while iterating G.quantization would be unsafe.
        to_remove = [nid for nid in G.quantization if nid.node_name not in G]
        for nid in to_remove:
            del G.quantization[nid]

        nodes_with_bad_quantization = self.find_nodes_with_bad_quantization(G)
        quantizer = UnifiedQuantizer.from_quantized_graph(G)
        # Check for quantization problems:
        # 1) need to force softmax/Sigmoid input to POW2 quantization
        # 2) need to check that all concats and splits have same input and
        #    output quantization
        # 3) need to check that all nodes have qrecs and that they are
        #    consistent
        nodes_with_bad_quantization |= set(
            G.nodes(node_classes=(ConcatParameters, SoftMaxParameters,
                                  SplitParameters, SigmoidActivationParameters)))
        G.quantization = quantizer.quantize(
            G, start_nodes=nodes_with_bad_quantization)
        G.add_dimensions()

    return G
def create_graph(filename, opts):
    """Build an ``NNGraph`` from a configuration file.

    The file is parsed with ``read_cfg``, and the graph body is populated by
    ``create_subgraph``. Every node that has inputs but no consumers (a leaf)
    is then wired to a freshly added output node.

    Args:
        filename: path of the configuration file to read.
        opts: mapping of options; ``name`` and ``value_cache`` are forwarded
            to the ``NNGraph`` constructor (``None`` when absent).

    Returns:
        The constructed ``NNGraph``.
    """
    cfg = read_cfg(filename)
    out_graph = NNGraph(model=cfg, filename=filename, name=opts.get('name'),
                        value_cache=opts.get('value_cache'))
    create_subgraph(out_graph, cfg)
    # Materialize the leaf list before adding outputs: add_output/add_edge
    # mutate the graph we would otherwise still be iterating.
    leaf_nodes = [node for node in out_graph.nodes()
                  if out_graph.out_degree(node) == 0 and out_graph.in_degree(node) > 0]
    for node in leaf_nodes:
        out_graph.add_edge(node, out_graph.add_output(), order=0)
    return out_graph
def report(self, G: NNGraph, nodes=None) -> Tabular:
    """Tabulate per-step activation, parameter and operation statistics.

    The report covers every step between the lowest and highest
    ``step_idx`` of ``nodes`` (inclusive), walked via ``graph_walk`` over
    ``G.graph_state``. Totals are appended when that range spans more than
    one step.

    Args:
        G: graph whose ``graph_state.steps`` / ``graph_state.liveness`` are
            walked.
        nodes: optional collection of nodes bounding the step range;
            defaults to all nodes of ``G``.

    Returns:
        The populated ``Tabular``.
    """
    ordered = sorted(G.nodes() if nodes is None else nodes,
                     key=lambda n: n.step_idx)
    lo_step = ordered[0].step_idx
    hi_step = ordered[-1].step_idx
    steps, liveness = G.graph_state.steps, G.graph_state.liveness

    # Header dimension order comes from the first output of the first step.
    head_node = steps[lo_step]['node']
    header_order = "x".join(head_node.out_dims[0].order)
    tab = Tabular()
    self.do_headers(header_order, tab)

    peak_active = total_params = total_ops = 0
    for idx, node, active, params_size, ops in graph_walk(steps, liveness):
        if not lo_step <= node.step_idx <= hi_step:
            continue
        total_params += params_size
        if ops:
            total_ops += ops
        peak_active = max(peak_active, active)
        # Constant inputs are listed only when _show_constants is set.
        show_row = self._show_constants or not isinstance(node, ConstantInputParameters)
        if show_row:
            self.do_operation(node, G, tab, idx, active, params_size, ops)

    if lo_step != hi_step:
        self.check_do_totals(tab, peak_active, total_params, total_ops)
    return tab
def report_graph(self, G: NNGraph, dot, all_ports, fake_idx, nodes=None, all_dims=False,
                 anonymise=False, expressions=False, qrecs=None, fusions=False, parent=None):
    """Emit the nodes and edges of ``G`` into the graphviz digraph ``dot``.

    Nodes are visited in ``G.dfs()`` order; only those in ``nodes`` (default:
    all of ``G``) are drawn, and ``FusionInputParameters`` nodes are skipped.
    Depending on ``expressions`` / ``fusions``, expression-fusion and fusion
    nodes are delegated to ``report_expression`` / ``report_fusion``. All
    other nodes are drawn as record boxes (blue when ``node.is_not_generated``,
    black otherwise). ``FusionOutputParameters`` nodes get no box of their own.

    Edges are labelled through ``in_label`` / ``out_label`` and drawn red
    when those report an error. With ``all_dims`` set:
    - in-edges from nodes outside ``nodes`` get a point-shaped placeholder
      source;
    - out-edge groups with no consumer in ``nodes`` get a blank plaintext
      placeholder target.

    Args:
        G: graph (or fusion subgraph) to draw.
        dot: graphviz digraph that receives ``node``/``edge`` calls.
        all_ports: dict mapping node -> ports; updated in place.
        fake_idx: starting counter for placeholder node names (``fake_<n>``).
        nodes: subset of nodes to draw; defaults to every node in ``G``.
        all_dims: also draw edges that cross the boundary of ``nodes``.
        anonymise: passed to the node-box builders.
        expressions: truthy to expand expression fusions; the string
            ``"quantized"`` requests the quantized view.
        qrecs: quantization records forwarded to the label builders.
        fusions: truthy to expand fusion nodes via ``report_fusion``.
        parent: forwarded to the label builders.
    """
    if nodes is None:
        nodes = set(G.nodes())
    for node in G.dfs():
        if node not in nodes:
            continue
        if isinstance(node, (FusionInputParameters)):
            continue
        if expressions and isinstance(node, ExpressionFusionParameters):
            all_ports[node] = self.report_expression(
                dot, G, node, anonymise=anonymise, report_quantized=expressions == "quantized")
        elif fusions and isinstance(node, FusionBase):
            # NOTE(review): fake_idx is a plain int passed by value and this
            # function returns nothing, so placeholders created inside the
            # fusion start from the same counter as those created below.
            # Duplicate "fake_<n>" names may merge in graphviz — confirm.
            all_ports[node] = self.report_fusion(dot, G, node, all_ports, fake_idx, all_dims=all_dims,
                                                 anonymise=anonymise, expressions=expressions, qrecs=qrecs)
        else:
            num_in_edges = len(G.indexed_in_edges(node.name))
            num_out_edges = len(G.indexed_out_edges(node.name))
            # Two-slot port list per node; filled by build_nodebox.
            ports = all_ports.setdefault(node, [None] * 2)
            if not isinstance(node, FusionOutputParameters):
                names = self.build_nodebox(node, ports, num_in_edges, num_out_edges, anon=anonymise)
                dot.node(
                    node.name, nohtml(names), shape='record',
                    xlabel=str(node.step_idx),
                    color="blue" if node.is_not_generated else "black")
        for edge in G.in_edges(node.name):
            # Edges from outside the drawn subset are only shown with all_dims.
            if edge.from_node not in nodes:
                if not all_dims:
                    continue
            out_port, in_port = self.get_ports(all_ports, edge)
            if edge.from_node in nodes:
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                to_node_id = self.get_to_id(all_ports, edge, in_port)
                edge_label, edge_error = self.in_label(
                    G, edge, qrecs, parent=parent,
                    from_node=not isinstance(edge.from_node, FusionInputParameters),
                    to_node=not isinstance(edge.to_node, FusionOutputParameters))
                dot.edge(from_node_id, to_node_id, xlabel=edge_label,
                         color="red" if edge_error else "black")
            else:
                # Source not drawn: stand in a point-shaped placeholder node.
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name, shape='point', fillcolor='black')
                to_node_id = self.get_to_id(all_ports, edge, in_port)
                edge_label, edge_error = self.in_label(G, edge, qrecs, parent=parent)
                dot.edge(fake_name, to_node_id, xlabel=edge_label,
                         color="red" if edge_error else "black")
        if not all_dims:
            continue
        for edge_group in G.indexed_out_edges(node.name):
            # Only outputs with no consumer in the drawn subset need a stub.
            if any(edge.to_node in nodes for edge in edge_group):
                continue
            edge = edge_group[0]
            out_port, _ = self.get_ports(all_ports, edge)
            fake_name = f'fake_{fake_idx}'
            fake_idx += 1
            dot.node(fake_name, shape='plaintext', label=' ', fillcolor='black')
            from_node_id = self.get_from_id(all_ports, edge, out_port)
            edge_label, edge_error = self.out_label(
                G, edge, qrecs, parent=parent,
                from_node=not isinstance(edge.from_node, FusionInputParameters),
                to_node=not isinstance(edge.to_node, FusionOutputParameters))
            dot.edge(from_node_id, fake_name, xlabel=edge_label,
                     color="red" if edge_error else "black")
def report(self, G: NNGraph, nodes=None, graph_format='PDF', all_dims=False,
           filename=None, view=True, anonymise=False, expressions=False,
           quant_labels=False):
    """Draw ``G`` (or a subset of its nodes) as a graphviz digraph.

    Nodes are visited in ``G.dfs()`` order. Expression fusions are expanded
    via ``report_expression`` when ``expressions`` is truthy; every other
    node is drawn as a record box labelled with its ``step_idx``. With
    ``all_dims`` set:
    - in-edges from undrawn nodes get a point-shaped placeholder source;
    - outputs with no drawn consumer get a blank placeholder target.

    Args:
        G: graph to draw.
        nodes: subset of nodes to draw; defaults to every node of ``G``.
        graph_format: graphviz output format.
        all_dims: also draw edges crossing the boundary of ``nodes``.
        filename: if given, render to this path.
        view: if true, render into a fresh private temporary directory and
            open it in the system viewer.
        anonymise: passed to the node-box builders.
        expressions: truthy to expand expression fusions; the string
            ``"quantized"`` requests the quantized view.
        quant_labels: passed to the edge-label builders.
    """
    if nodes is None:
        nodes = set(G.nodes())

    self.init_name_cache()
    # try/finally keeps init_name_cache/reset_name_cache paired even if
    # drawing or rendering raises.
    try:
        all_ports = {}
        graph_name = getattr(G, 'graphname', 'graph')
        dot = Digraph(comment=graph_name, format=graph_format,
                      node_attr={'height': '.1'},
                      edge_attr={'fontsize': '10.0'})
        fake_idx = 0
        for node in G.dfs():
            if node not in nodes:
                continue
            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot, G, node, anonymise=anonymise,
                    report_quantized=expressions == "quantized")
            else:
                num_in_edges = len(G.indexed_in_edges(node.name))
                num_out_edges = len(G.indexed_out_edges(node.name))
                ports = all_ports.setdefault(node, [None] * 2)
                names = self.build_nodebox(node, ports, num_in_edges, num_out_edges,
                                           anon=anonymise)
                dot.node(node.name, nohtml(names), shape='record',
                         xlabel=str(node.step_idx))
            for edge in G.in_edges(node.name):
                # Edges from undrawn nodes are only shown with all_dims.
                if edge.from_node not in nodes and not all_dims:
                    continue
                out_port, in_port = self.get_ports(all_ports, edge)
                if edge.from_node in nodes:
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    dot.edge(from_node_id, to_node_id,
                             xlabel=self.in_label(G, node, edge.to_idx, quant_labels))
                else:
                    # Source not drawn: stand in a point-shaped placeholder.
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name, shape='point', fillcolor='black')
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    dot.edge(fake_name, to_node_id,
                             xlabel=self.in_label(G, node, edge.to_idx, quant_labels))
            if not all_dims:
                continue
            for edge_group in G.indexed_out_edges(node.name):
                # Only outputs with no consumer in the drawn subset need a stub.
                if any(edge.to_node in nodes for edge in edge_group):
                    continue
                edge = edge_group[0]
                out_port, _ = self.get_ports(all_ports, edge)
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name, shape='plaintext', label=' ', fillcolor='black')
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                dot.edge(from_node_id, fake_name,
                         xlabel=self.out_label(G, node, edge.from_idx, quant_labels))

        if filename:
            dot.render(filename, cleanup=True)
        if view:
            # tempfile.mktemp() is deprecated: it only invents a name, which
            # another process could claim first. mkdtemp() atomically creates
            # a private directory instead. The directory is not removed
            # because the external viewer may still need the rendered file
            # after this call returns.
            dot.view('graph.gv', directory=tempfile.mkdtemp(),
                     cleanup=True, quiet=True)
    finally:
        self.reset_name_cache()