def fix_split_in_edges(G: NNGraph):
    """Re-create the input edge of each split node that does not land on input 0.

    For every ``SplitParameters`` node in ``G`` the first incoming edge is
    inspected. When its ``to_idx`` is non-zero the edge is removed and a new
    ``NNEdge`` between the same two nodes is added. The new edge keeps the
    original ``from_idx`` and is built without an explicit ``to_idx``, so it
    takes the ``NNEdge`` default (presumably 0 - confirm against NNEdge).

    Args:
        G: graph to patch; it is modified in place.
    """
    # Snapshot the split nodes up front because edges are rewritten in the loop.
    split_nodes = list(
        filter(lambda candidate: isinstance(candidate, SplitParameters), G.nodes()))
    for split_node in split_nodes:
        incoming = G.in_edges(split_node.name)[0]
        if incoming.to_idx != 0:
            G.remove_edge(incoming)
            G.add_edge(
                NNEdge(incoming.from_node,
                       incoming.to_node,
                       from_idx=incoming.from_idx))
def propagate_downwards(G: NNGraph):
    """Push dimension-name hints from producers to consumers.

    Nodes are visited in ``G.dfs()`` order. For each node two steps run:

    1. If the node has input hints, derive its output hints from them using
       a rule that depends on the node type (reshape, global pool, broadcast
       linear op, concat, or the generic copy rule).
    2. If the node has output hints, copy each one onto the matching input
       slot of every downstream node that does not already have a hint there.

    Args:
        G: graph whose nodes carry ``in_dims_hint`` / ``out_dims_hint``;
            the hints (and reshape ``old_shape`` naming) are updated in place.
    """
    for node in G.dfs():
        # Step 1: in hints -> out hints. A node that does not want the generic
        # copy should already have its out hints set.
        in_hints = node.in_dims_hint
        if in_hints is not None:
            if isinstance(node, ReshapeParameters):
                if len(node.old_shape) == len(in_hints[0]):
                    LOG.debug("set reshape %s in dims hint %s",
                              node.name, in_hints[0])
                    node.old_shape.apply_naming_hints(in_hints[0])
            elif isinstance(node, GlobalPoolParameters):
                if node.keep_dims:
                    node.out_dims_hint = deepcopy(in_hints)
            elif isinstance(node, MatrixBroadcastedLinearOpParameters):
                # The longest available input hint wins; the first one wins ties.
                longest_hint = max(
                    (candidate for candidate in in_hints if candidate is not None),
                    key=len,
                    default=None)
                if longest_hint is not None:
                    node.out_dims_hint = [longest_hint]
            elif isinstance(node, ConcatParameters):
                # Any concat input without a hint borrows the first known one.
                known_hint = None
                for candidate in in_hints:
                    if candidate is not None:
                        known_hint = candidate
                        break
                if known_hint:
                    LOG.debug("set concat %s in dims hint %s",
                              node.name, known_hint)
                    for in_edge in G.in_edges(node.name):
                        if not in_hints[in_edge.to_idx]:
                            in_hints[in_edge.to_idx] = known_hint
                    node.out_dims_hint = [known_hint]
            elif node.out_dims_hint is None:
                node.out_dims_hint = deepcopy(in_hints)

        # Step 2: out hints -> in hints of the consumers.
        out_hints = node.out_dims_hint
        if out_hints is None:
            continue
        LOG.debug("propagate down hint from %s", node.name)
        for out_edge in G.out_edges(node.name):
            hint = out_hints[out_edge.from_idx]
            if hint is None:
                continue
            consumer = out_edge.to_node
            if consumer.in_dims_hint is None:
                consumer.in_dims_hint = SparseList()
            # Never overwrite a hint the consumer already has.
            if consumer.in_dims_hint[out_edge.to_idx] is None:
                consumer.in_dims_hint[out_edge.to_idx] = hint
def propagate_upwards(G: NNGraph):
    """Push dimension-name hints from consumers back to producers.

    Nodes are visited in ``G.dfs(reverse=True)`` order. For each node:

    1. If the node has output hints, derive its input hints from them using a
       rule chosen by node type. ``MatrixMulParameters`` nodes are skipped
       entirely (nothing is derived and nothing is propagated from them).
    2. If the node has input hints, copy each one onto the matching output
       slot of the upstream node when that slot is still empty. When the
       upstream node is an ``InputParameters`` whose ``dims`` are shorter
       than the hint just stored, its dims are left-padded with 1s.

    Args:
        G: graph whose nodes carry ``in_dims_hint`` / ``out_dims_hint``;
            hints, reshape shapes and input dims are updated in place.
    """
    for node in G.dfs(reverse=True):
        # Step 1: out hints -> in hints. A node that does not want the generic
        # copy should already have its in hints set.
        out_hints = node.out_dims_hint
        if out_hints is not None:
            if isinstance(node, ReshapeParameters):
                missing_dims = len(out_hints[0]) - len(node.shape)
                if missing_dims > 0:
                    # Left-pad the reshape target so it matches the hint rank.
                    node.shape = Dim.unnamed(
                        [1] * missing_dims + node.shape.shape)
                node.shape.apply_naming_hints(out_hints[0])
                if node.in_dims_hint is None:
                    positional_names = [
                        str(axis) for axis in range(len(node.old_shape))]
                    node.in_dims_hint = SparseList([positional_names])
            elif isinstance(node, MatrixBroadcastedLinearOpParameters):
                node.in_dims_hint = [out_hints[0], out_hints[0]]
            elif isinstance(node, MatrixMulParameters):
                continue
            elif isinstance(node, GlobalPoolParameters):
                if node.keep_dims:
                    node.in_dims_hint = deepcopy(out_hints)
            elif (isinstance(node, ConstantInputParameters)
                  and not node.dims.is_named):
                node.dims.apply_naming_hints(out_hints[0])
            elif node.in_dims_hint is None:
                node.in_dims_hint = deepcopy(out_hints)

        # Step 2: in hints -> out hints of the producers.
        in_hints = node.in_dims_hint
        if in_hints is None:
            continue
        for in_edge in G.in_edges(node.name):
            hint = in_hints[in_edge.to_idx]
            if hint is None:
                continue
            producer = in_edge.from_node
            if producer.out_dims_hint is None:
                producer.out_dims_hint = SparseList()
            if producer.out_dims_hint[in_edge.from_idx] is not None:
                # The producer already has a hint on this output; keep it.
                continue
            producer.out_dims_hint[in_edge.from_idx] = hint
            if not isinstance(producer, InputParameters):
                continue
            assert in_edge.from_idx == 0, "input node should only have one output"
            padding = len(hint) - len(producer.dims)
            if padding > 0:
                producer.dims = Dim.unnamed(
                    [1] * padding + producer.dims.shape)
def propagate_upwards(G: NNGraph):
    """Push dimension-name hints from consumers back to producers.

    Nodes are visited in ``G.dfs(reverse=True)`` order. For each node:

    1. If the node has output hints and no input hints yet, input hints are
       derived: a reshape gets positional names (``"0"``, ``"1"``, ...) for
       its ``old_shape`` after naming its ``shape`` from the output hint;
       every other node gets a deep copy of its output hints.
    2. If the node has input hints, each one is copied onto the matching
       output slot of the upstream node when that slot is still empty.

    Input slots without a hint are skipped in step 2. Without that guard an
    empty ``SparseList`` would be created on the producer, which would then
    look as if it had output hints and be processed in step 1 with a ``None``
    hint.

    Args:
        G: graph whose nodes carry ``in_dims_hint`` / ``out_dims_hint``;
            hints and reshape shape naming are updated in place.
    """
    for node in G.dfs(reverse=True):
        # First propagate the out dim hints to the in dim hints.
        # Any node that does not want this to happen should set its in dim hints.
        if node.out_dims_hint is not None:
            if isinstance(node, ReshapeParameters):
                assert len(node.shape) == len(node.out_dims_hint[0]), \
                    "reshape shape and out dims hint differ in length"
                node.shape.apply_naming_hints(node.out_dims_hint[0])
                if node.in_dims_hint is None:
                    node.in_dims_hint = SparseList(
                        [["%s" % i for i in range(len(node.old_shape))]])
            elif node.in_dims_hint is None:
                node.in_dims_hint = deepcopy(node.out_dims_hint)

        # If we have an in dim hint then propagate it to upstream nodes.
        if node.in_dims_hint is not None:
            for edge in G.in_edges(node.name):
                hint = node.in_dims_hint[edge.to_idx]
                if hint is None:
                    # Nothing known for this input: do not touch the producer.
                    continue
                if edge.from_node.out_dims_hint is None:
                    edge.from_node.out_dims_hint = SparseList()
                if edge.from_node.out_dims_hint[edge.from_idx] is None:
                    edge.from_node.out_dims_hint[edge.from_idx] = hint
def report_graph(self, G: NNGraph, dot, all_ports, fake_idx, nodes=None,
                 all_dims=False, anonymise=False, expressions=False,
                 qrecs=None, fusions=False, parent=None):
    """Emit graphviz nodes and edges for ``G`` into ``dot``.

    Nodes are visited in ``G.dfs()`` order. Nodes outside ``nodes`` and
    ``FusionInputParameters`` nodes are skipped. Expression fusions (when
    ``expressions`` is truthy) and other fusions (when ``fusions`` is truthy)
    are delegated to ``report_expression`` / ``report_fusion``; every other
    node is drawn as a record box, except ``FusionOutputParameters`` which
    only gets a port entry. Incoming edges are then drawn, and with
    ``all_dims`` set, edges to or from nodes outside ``nodes`` are drawn
    against small placeholder nodes named ``fake_<n>``.

    Args:
        G: graph to draw.
        dot: graphviz graph object receiving ``node`` / ``edge`` calls.
        all_ports: dict mapping node -> port info; filled in as nodes are drawn.
        fake_idx: first index to use when naming placeholder nodes. It is
            incremented locally only; the updated value is not returned.
        nodes: nodes to include; defaults to every node of ``G``.
        all_dims: also draw edges whose other end is not in ``nodes``.
        anonymise: forwarded to the node box / expression / fusion reporters.
        expressions: enables expression reporting; the value ``"quantized"``
            requests the quantized form.
        qrecs: forwarded to the label builders and to ``report_fusion``.
        fusions: enables drawing the inside of fusion nodes.
        parent: forwarded to the label builders.
    """
    if nodes is None:
        nodes = set(G.nodes())

    for node in G.dfs():
        if node not in nodes or isinstance(node, FusionInputParameters):
            continue

        if expressions and isinstance(node, ExpressionFusionParameters):
            all_ports[node] = self.report_expression(
                dot, G, node,
                anonymise=anonymise,
                report_quantized=expressions == "quantized")
        elif fusions and isinstance(node, FusionBase):
            all_ports[node] = self.report_fusion(
                dot, G, node, all_ports, fake_idx,
                all_dims=all_dims,
                anonymise=anonymise,
                expressions=expressions,
                qrecs=qrecs)
        else:
            in_count = len(G.indexed_in_edges(node.name))
            out_count = len(G.indexed_out_edges(node.name))
            node_ports = all_ports.setdefault(node, [None] * 2)
            if not isinstance(node, FusionOutputParameters):
                box_label = self.build_nodebox(
                    node, node_ports, in_count, out_count, anon=anonymise)
                box_color = "blue" if node.is_not_generated else "black"
                dot.node(node.name, nohtml(box_label), shape='record',
                         xlabel=str(node.step_idx), color=box_color)

        # Incoming edges: real source when it is drawn, placeholder otherwise.
        for edge in G.in_edges(node.name):
            source_drawn = edge.from_node in nodes
            if not (source_drawn or all_dims):
                continue
            out_port, in_port = self.get_ports(all_ports, edge)
            if source_drawn:
                tail_id = self.get_from_id(all_ports, edge, out_port)
                label_options = {
                    'from_node': not isinstance(
                        edge.from_node, FusionInputParameters),
                    'to_node': not isinstance(
                        edge.to_node, FusionOutputParameters),
                }
            else:
                tail_id = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(tail_id, shape='point', fillcolor='black')
                label_options = {}
            head_id = self.get_to_id(all_ports, edge, in_port)
            edge_label, edge_error = self.in_label(
                G, edge, qrecs, parent=parent, **label_options)
            dot.edge(tail_id, head_id, xlabel=edge_label,
                     color="red" if edge_error else "black")

        if not all_dims:
            continue

        # Outgoing edge groups with no drawn consumer end at a placeholder.
        for edge_group in G.indexed_out_edges(node.name):
            if any(member.to_node in nodes for member in edge_group):
                continue
            edge = edge_group[0]
            out_port, _ = self.get_ports(all_ports, edge)
            sink_name = f'fake_{fake_idx}'
            fake_idx += 1
            dot.node(sink_name, shape='plaintext', label=' ',
                     fillcolor='black')
            tail_id = self.get_from_id(all_ports, edge, out_port)
            edge_label, edge_error = self.out_label(
                G, edge, qrecs, parent=parent,
                from_node=not isinstance(
                    edge.from_node, FusionInputParameters),
                to_node=not isinstance(
                    edge.to_node, FusionOutputParameters))
            dot.edge(tail_id, sink_name, xlabel=edge_label,
                     color="red" if edge_error else "black")
def report(self, G: NNGraph, nodes=None, graph_format='PDF', all_dims=False,
           filename=None, view=True, anonymise=False, expressions=False,
           quant_labels=False):
    """Draw ``G`` with graphviz and optionally render and/or view it.

    Nodes are visited in ``G.dfs()`` order and drawn as record boxes
    (expression fusions are delegated to ``report_expression`` when
    ``expressions`` is truthy). Incoming edges are drawn for every included
    node; with ``all_dims`` set, edges to or from nodes outside ``nodes``
    are drawn against placeholder nodes named ``fake_<n>``.

    Args:
        G: graph to draw; ``G.graphname`` is used as the graph comment when
            present, otherwise ``'graph'``.
        nodes: nodes to include; defaults to every node of ``G``.
        graph_format: output format handed to ``Digraph``.
        all_dims: also draw edges whose other end is not in ``nodes``.
        filename: when truthy, the graph is rendered to this path.
        view: when truthy, the graph is rendered to a temporary file and
            opened through ``dot.view``.
        anonymise: forwarded to the node box / expression reporters.
        expressions: enables expression reporting; the value ``"quantized"``
            requests the quantized form.
        quant_labels: forwarded to the edge label builders.
    """
    if nodes is None:
        nodes = set(G.nodes())
    self.init_name_cache()
    all_ports = {}
    graph_name = G.graphname if hasattr(G, 'graphname') else 'graph'
    dot = Digraph(comment=graph_name, format=graph_format,
                  node_attr={'height': '.1'},
                  edge_attr={'fontsize': '10.0'})
    fake_idx = 0
    for node in G.dfs():
        if node not in nodes:
            continue
        if expressions and isinstance(node, ExpressionFusionParameters):
            all_ports[node] = self.report_expression(
                dot, G, node, anonymise=anonymise,
                report_quantized=expressions == "quantized")
        else:
            num_in_edges = len(G.indexed_in_edges(node.name))
            num_out_edges = len(G.indexed_out_edges(node.name))
            ports = all_ports.setdefault(node, [None] * 2)
            names = self.build_nodebox(node, ports, num_in_edges,
                                       num_out_edges, anon=anonymise)
            dot.node(node.name, nohtml(names), shape='record',
                     xlabel=str(node.step_idx))

        # Incoming edges: real source when it is drawn, placeholder otherwise.
        for edge in G.in_edges(node.name):
            if edge.from_node not in nodes:
                if not all_dims:
                    continue
            out_port, in_port = self.get_ports(all_ports, edge)
            if edge.from_node in nodes:
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                to_node_id = self.get_to_id(all_ports, edge, in_port)
                dot.edge(from_node_id, to_node_id,
                         xlabel=self.in_label(G, node, edge.to_idx,
                                              quant_labels))
            else:
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name, shape='point', fillcolor='black')
                to_node_id = self.get_to_id(all_ports, edge, in_port)
                dot.edge(fake_name, to_node_id,
                         xlabel=self.in_label(G, node, edge.to_idx,
                                              quant_labels))

        if not all_dims:
            continue

        # Outgoing edge groups with no drawn consumer end at a placeholder.
        for edge_group in G.indexed_out_edges(node.name):
            if any(edge.to_node in nodes for edge in edge_group):
                continue
            edge = edge_group[0]
            out_port, _ = self.get_ports(all_ports, edge)
            fake_name = f'fake_{fake_idx}'
            fake_idx += 1
            dot.node(fake_name, shape='plaintext', label=' ',
                     fillcolor='black')
            from_node_id = self.get_from_id(all_ports, edge, out_port)
            dot.edge(from_node_id, fake_name,
                     xlabel=self.out_label(G, node, edge.from_idx,
                                           quant_labels))

    if filename:
        dot.render(filename, cleanup=True)
    if view:
        # tempfile.mktemp is deprecated and race-prone: it only returns a
        # name that another process may claim first. Reserve a real file
        # atomically instead and let graphviz overwrite it.
        with tempfile.NamedTemporaryFile(suffix='.gv',
                                         delete=False) as tmp_source:
            filename = tmp_source.name
        dot.view(filename, cleanup=True, quiet=True)
    self.reset_name_cache()
def quantize_backward(self, G: NNGraph, result, edge_recs, node, force_out=None):
    """Re-quantize ``node`` so its output satisfies ``force_out``.

    The node's quantization record is computed from its current input
    qtypes. If the output does not meet ``force_out`` the method recurses
    into the producing node(s) with an adjusted ``force_out`` and then
    recomputes this node once. Each pass writes the node's output qtypes
    into ``edge_recs`` and its record into ``result``.

    Args:
        G: graph containing ``node``.
        result: mapping updated with ``NodeId -> quantization record``.
        edge_recs: mapping updated with ``edge.params -> output qtype``.
        node: node to quantize.
        force_out: qtype constraint on the node's output. NOTE(review): the
            generic (last) branch reads ``force_out.q`` without a None
            check, so callers presumably always pass a value - confirm.

    Returns:
        The quantization record computed for ``node``.

    Raises:
        NotImplementedError: when no solution is found after one
            recalculation, for softmax nodes, for generic nodes with more
            than one input edge, or when a filter's input has too few
            fractional bits to give up.
    """
    LOG.debug("quantize backwards %s", node.name)
    # Set once the node has been through one pass; a second failure is fatal.
    recalculated = False
    while True:
        in_qs = self.get_in_qs(G, edge_recs, node)
        if self.is_filter_node(node):
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs, force_out=force_out)
                # Record the per-node results produced for the fusion.
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            else:
                qrec = self.calculate_q(node, self._activation_stats.get(
                    NodeId(node, None)), in_qs, self._force_width, force_out=force_out)
            if force_out and force_out.q is not None and qrec.out_qs[
                    0].q < force_out.q:
                if recalculated:
                    raise NotImplementedError(
                        "no quantization solution found")
                # NOTE(review): uses qrec.q here but qrec.out_qs[0].q in the
                # test above - presumably equivalent; confirm.
                bits_to_gain = force_out.q - qrec.q
                if bits_to_gain > in_qs[0].q:
                    raise NotImplementedError()
                # Try to adjust the inputs to satisfy and then
                # recalculate
                pnode = G.in_edges(node.name)[0].from_node
                self.quantize_backward(G, result, edge_recs, pnode,
                                       force_out=QType(bits=force_out.bits,
                                                       q=in_qs[0].q - bits_to_gain,
                                                       signed=True))
        elif isinstance(node, ConcatParameters):
            assert not recalculated
            # Common format for all inputs: widest width, smallest q.
            max_width = max(in_q.bits for in_q in in_qs)
            min_q = min(in_q.q for in_q in in_qs)
            if force_out:
                if not self.satisfied(force_out.bits, max_width):
                    max_width = force_out.bits
                if not self.satisfied(force_out.q, min_q):
                    min_q = force_out.q
            LOG.debug("normalizing concat to %s",
                      QType(bits=max_width, q=min_q, signed=True))
            # Force every producer whose qtype differs onto the common format.
            for pidx, pnode in enumerate(
                    [edge.from_node for edge in G.in_edges(node.name)]):
                pqrec = in_qs[pidx]
                if pqrec.q != min_q or pqrec.bits != max_width:
                    self.quantize_backward(G, result, edge_recs, pnode,
                                           force_out=QType(bits=max_width,
                                                           q=min_q,
                                                           signed=True))
            o_q = QType(bits=max_width, q=min_q, signed=True)
            # Inputs are re-read because the recursion above may have changed them.
            qrec = SymmetricQuantizationRecord(in_qs=self.get_in_qs(
                G, edge_recs, node), out_qs=[o_q])
        elif isinstance(node, SoftMaxParameters):
            raise NotImplementedError(
                "softmax kernel cannot change width or q")
        else:
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs, force_out=force_out)
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            else:
                qrec = self.calculate_q(node, self._activation_stats.get(
                    NodeId(node, None)), in_qs, self._force_width, force_out=force_out)
            o_q = qrec.out_qs[0]
            if not (self.satisfied(force_out.q, o_q.q) and
                    self.satisfied(force_out.bits, o_q.bits)):
                if recalculated:
                    raise NotImplementedError(
                        "no quantization solution found")
                if len(G.in_edges(node.name)) > 1:
                    raise NotImplementedError(
                        "Nodes with multiple input edges \
                            need custom handling")
                # Pass the same constraint on to the single producer.
                pnode = G.in_edges(node.name)[0].from_node
                self.quantize_backward(G, result, edge_recs, pnode,
                                       force_out=force_out)
        # Publish this pass's result on the outgoing edges and in the result map.
        for edges in G.indexed_out_edges(node.name):
            for edge in edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        result[NodeId(node, None)] = qrec
        o_q = qrec.out_qs[0]
        if self.satisfied_force(force_out, o_q):
            break
        if recalculated:
            raise NotImplementedError("no quantization solution found")
        LOG.debug("recalculate %s", node.name)
        recalculated = True
    LOG.debug("back complete %s %s", node.name, qrec)
    return qrec