def quantize_forward(self, G: NNGraph, edge_recs, result=None):
    """Quantize every node of ``G`` in step order, propagating along edges.

    Args:
        G: graph whose ``graph_state.steps`` fixes the visiting order.
        edge_recs: dict keyed by edge params; each visited node writes its
            ``qrec.out_qs[edge.from_idx]`` into every outgoing edge. Updated
            in place and read back via ``get_in_qs`` for later nodes.
        result: container that receives the per-node records. A fresh
            ``QuantizationSet`` is created when None.

    Returns:
        The (possibly newly created) record container. Processing stops
        early, without raising, at the first node whose record is falsy;
        that falsy record is still stored.
    """
    qset = QuantizationSet() if result is None else result
    for cur in [step['node'] for step in G.graph_state.steps]:
        LOG.debug("quantize forward %s", cur.name)
        in_qs = self.get_in_qs(G, edge_recs, cur)
        if isinstance(cur, ConvFusionParameters):
            # Fusions return their own record plus records for inner nodes.
            qrec, fused_recs = self.quantize_fusion(G, cur, in_qs)
            for inner_id, inner_rec in fused_recs.items():
                qset[inner_id] = inner_rec
        elif isinstance(cur, ConcatParameters):
            # Concat inputs must agree, which may requantize its producers.
            qrec = self.quantize_backward(G, qset, edge_recs, cur)
        else:
            key = NodeId(cur, None)
            qrec = self.calculate_q(
                cur,
                self._activation_stats.get(key),
                self._filter_stats.get(key),
                in_qs,
                self._min_qsnr,
                self._force_width)
        qset[NodeId(cur, None)] = qrec
        if not qrec:
            break
        out_edges = (edge
                     for group in G.indexed_out_edges(cur.name)
                     for edge in group)
        for edge in out_edges:
            edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
    return qset
def quantize(self, G: NNGraph) -> OrderedDict:
    """Quantize ``G`` in step order and return an ordered map of records.

    Input nodes get an empty input list. Every other node reads its inputs
    from the output entries previously propagated along its incoming edges.
    For a ``FusionParameters`` node, each contained node is quantized in
    turn, with one node's ``out_qs`` feeding the next. The fusion itself
    gets a ``QuantizationRecord`` spanning that chain.

    Returns:
        ``OrderedDict`` keyed by ``NodeId``. Processing stops at the first
        node whose record is falsy; that falsy record is still stored.
    """
    propagated = {}
    records = OrderedDict()
    for step in G.graph_state.steps:
        node = step['node']
        in_qs = ([] if isinstance(node, InputParameters)
                 else [propagated[edge.params]
                       for edge in G.indexed_in_edges(node.name)])
        if not isinstance(node, FusionParameters):
            node_key = NodeId(node, None)
            qrec = self.calculate_q(
                node,
                self._activation_stats.get(node_key),
                self._filter_stats.get(node_key),
                in_qs, self._min_qsnr, self._force_width)
        else:
            chain_qs = in_qs
            for inner in node.contained_nodes():
                inner_key = NodeId(node, inner)
                inner_rec = self.calculate_q(
                    inner,
                    self._activation_stats.get(inner_key),
                    self._filter_stats.get(inner_key),
                    chain_qs, self._min_qsnr, self._force_width)
                records[inner_key] = inner_rec
                chain_qs = inner_rec.out_qs
            qrec = QuantizationRecord(in_qs=in_qs, out_qs=chain_qs)
        records[NodeId(node, None)] = qrec
        if not qrec:
            break
        for group in G.indexed_out_edges(node.name):
            for edge in group:
                propagated[edge.params] = qrec.out_qs[edge.from_idx]
    return records
def initialize_edge_recs(G: NNGraph, qrecs):
    """Build the edge-rec map from the quantization records already in ``qrecs``.

    For each node in step order, every outgoing edge is mapped
    (by ``edge.params``) to ``qrecs[NodeId(node)].out_qs[edge.from_idx]``.
    A node missing from ``qrecs`` raises ``KeyError``.
    """
    recs_by_edge = {}
    nodes_in_order = [step['node'] for step in G.graph_state.steps]
    for nd in nodes_in_order:
        owner_rec = qrecs[NodeId(nd)]
        recs_by_edge.update(
            (edge.params, owner_rec.out_qs[edge.from_idx])
            for edge_group in G.indexed_out_edges(nd.name)
            for edge in edge_group)
    return recs_by_edge
def quantize_forward(self, G: NNGraph, edge_recs, dtype=np.int8):
    """Quantize nodes in step order into ``self.qrecs``, propagating along edges.

    ``ConvFusionParameters`` and ``ActivationFusion`` nodes go through
    ``quantize_fusion``; all other nodes go through ``calculate_q`` with
    their activation statistics. Each record is stored under
    ``NodeId(node, None)``. Processing stops at the first falsy record,
    which is still stored. ``edge_recs`` (keyed by edge params) is updated
    in place with each node's ``out_qs[edge.from_idx]``. Returns None.
    """
    ordered = [step['node'] for step in G.graph_state.steps]
    for cur in ordered:
        LOG.debug("quantize forward %s", cur.name)
        in_qs = self.get_in_qs(G, edge_recs, cur)
        qrec = (self.quantize_fusion(G, cur, in_qs, dtype)
                if isinstance(cur, (ConvFusionParameters, ActivationFusion))
                else self.calculate_q(
                    G, cur,
                    self._activation_stats.get(NodeId(cur, None)),
                    in_qs, dtype))
        self.qrecs[NodeId(cur, None)] = qrec
        if not qrec:
            break
        out_edges = (edge
                     for group in G.indexed_out_edges(cur.name)
                     for edge in group)
        for edge in out_edges:
            edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
def propagate_forward(self, G: NNGraph, edge_recs, start_node, new_out_qrec, result):
    '''Propagate a new output qrec at node start_node in the graph

    Nodes before ``start_node`` (in step order) are skipped. ``start_node``
    is requantized backwards with ``force_out=new_out_qrec``. Every later
    node is then requantized forwards from the updated ``edge_recs``.
    Records go into ``result`` under ``NodeId(node, None)``. The walk stops
    at the first falsy record, which is still stored. Returns None.
    '''
    def requantize(nd):
        # Forward requantization for a node downstream of start_node.
        LOG.debug("propagate forwards %s", nd.name)
        in_qs = self.get_in_qs(G, edge_recs, nd)
        if isinstance(nd, ConvFusionParameters):
            fused_qrec, fused_recs = self.quantize_fusion(G, nd, in_qs)
            for inner_id, inner_rec in fused_recs.items():
                result[inner_id] = inner_rec
            return fused_qrec
        if isinstance(nd, ConcatParameters):
            return self.quantize_backward(G, result, edge_recs, nd)
        key = NodeId(nd, None)
        return self.calculate_q(
            nd,
            self._activation_stats.get(key),
            self._filter_stats.get(key),
            in_qs, self._min_qsnr, self._force_width)

    reached_start = False
    for cur in [step['node'] for step in G.graph_state.steps]:
        if reached_start:
            qrec = requantize(cur)
        elif cur == start_node:
            reached_start = True
            qrec = self.quantize_backward(G, result, edge_recs, cur,
                                          force_out=new_out_qrec)
        else:
            continue
        result[NodeId(cur, None)] = qrec
        if not qrec:
            break
        for group in G.indexed_out_edges(cur.name):
            for edge in group:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
def report_graph(self, G: NNGraph, dot, all_ports, fake_idx, nodes=None, all_dims=False,
                 anonymise=False, expressions=False, qrecs=None, fusions=False, parent=None):
    """Draw the nodes of ``G`` (optionally a subset) and their edges into ``dot``.

    Args:
        G: graph to draw.
        dot: graphviz graph object; ``dot.node`` / ``dot.edge`` are called on it.
        all_ports: dict mapping node -> port info. Filled via ``setdefault``
            for plain nodes, or with the return value of
            ``report_expression`` / ``report_fusion`` for expanded nodes.
        fake_idx: starting counter for placeholder node names ``fake_<n>``.
        nodes: nodes to draw; defaults to every node in ``G``.
        all_dims: also draw edges that connect to nodes outside ``nodes``,
            attaching them to placeholder nodes.
        anonymise: forwarded to the node-box/expression/fusion renderers.
        expressions: when truthy, ``ExpressionFusionParameters`` nodes are
            drawn via ``report_expression``; the value "quantized" selects
            the quantized variant.
        qrecs: forwarded to the edge-label helpers and ``report_fusion``.
        fusions: when truthy, ``FusionBase`` nodes are drawn via
            ``report_fusion``.
        parent: forwarded to ``in_label`` / ``out_label``.

    Returns None.
    """
    if nodes is None:
        nodes = set(G.nodes())
    for node in G.dfs():
        if node not in nodes:
            continue
        # Fusion input placeholders are never drawn as nodes of their own.
        if isinstance(node, (FusionInputParameters)):
            continue
        if expressions and isinstance(node, ExpressionFusionParameters):
            all_ports[node] = self.report_expression(
                dot, G, node, anonymise=anonymise, report_quantized=expressions == "quantized")
        elif fusions and isinstance(node, FusionBase):
            # NOTE(review): fake_idx is a plain int passed by value; placeholder
            # names created inside report_fusion do not advance this counter,
            # so 'fake_<n>' names may repeat across nesting levels — confirm.
            all_ports[node] = self.report_fusion(dot, G, node, all_ports, fake_idx, all_dims=all_dims,
                                                 anonymise=anonymise, expressions=expressions,
                                                 qrecs=qrecs)
        else:
            num_in_edges = len(G.indexed_in_edges(node.name))
            num_out_edges = len(G.indexed_out_edges(node.name))
            # Two-slot port record; presumably filled in by build_nodebox — confirm.
            ports = all_ports.setdefault(node, [None] * 2)
            # Fusion output placeholders keep port info but get no visible box.
            if not isinstance(node, FusionOutputParameters):
                names = self.build_nodebox(node, ports, num_in_edges, num_out_edges,
                                           anon=anonymise)
                dot.node(
                    node.name, nohtml(names), shape='record', xlabel=str(node.step_idx),
                    color="blue" if node.is_not_generated else "black")
        # NOTE(review): nesting of the edge loops below is reconstructed from a
        # collapsed source; they are placed so edges are drawn for every node kind.
        for edge in G.in_edges(node.name):
            # Edges from undrawn producers are only shown when all_dims is set.
            if edge.from_node not in nodes:
                if not all_dims:
                    continue
            out_port, in_port = self.get_ports(all_ports, edge)
            if edge.from_node in nodes:
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                to_node_id = self.get_to_id(all_ports, edge, in_port)
                edge_label, edge_error = self.in_label(
                    G, edge, qrecs, parent=parent,
                    from_node=not isinstance(edge.from_node, FusionInputParameters),
                    to_node=not isinstance(edge.to_node, FusionOutputParameters))
                dot.edge(from_node_id, to_node_id, xlabel=edge_label,
                         color="red" if edge_error else "black")
            else:
                # Producer not drawn: attach the edge to a point placeholder.
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name, shape='point', fillcolor='black')
                to_node_id = self.get_to_id(all_ports, edge, in_port)
                edge_label, edge_error = self.in_label(G, edge, qrecs, parent=parent)
                dot.edge(fake_name, to_node_id, xlabel=edge_label,
                         color="red" if edge_error else "black")
        if not all_dims:
            continue
        # Output groups with no drawn consumer end in a blank plaintext placeholder.
        for edge_group in G.indexed_out_edges(node.name):
            if any(edge.to_node in nodes for edge in edge_group):
                continue
            edge = edge_group[0]
            out_port, _ = self.get_ports(all_ports, edge)
            fake_name = f'fake_{fake_idx}'
            fake_idx += 1
            dot.node(fake_name, shape='plaintext', label=' ', fillcolor='black')
            from_node_id = self.get_from_id(all_ports, edge, out_port)
            edge_label, edge_error = self.out_label(
                G, edge, qrecs, parent=parent,
                from_node=not isinstance(edge.from_node, FusionInputParameters),
                to_node=not isinstance(edge.to_node, FusionOutputParameters))
            dot.edge(from_node_id, fake_name, xlabel=edge_label,
                     color="red" if edge_error else "black")
def report(self, G: NNGraph, nodes=None, graph_format='PDF', all_dims=False, filename=None,
           view=True, anonymise=False, expressions=False, quant_labels=False):
    """Render ``G`` (or a subset of its nodes) as a graphviz ``Digraph``.

    Args:
        G: graph to draw; its ``graphname`` attribute, if present, becomes
            the graph comment.
        nodes: nodes to draw; defaults to every node of ``G``.
        graph_format: graphviz output format.
        all_dims: also draw edges that connect to nodes outside ``nodes``,
            attaching them to placeholder nodes.
        filename: when given, render to this path (the intermediate source
            file is removed after rendering).
        view: when true, render to a freshly created temporary file and open
            it in the viewer.
        anonymise: forwarded to the node-box and expression renderers.
        expressions: when truthy, ``ExpressionFusionParameters`` nodes are
            drawn via ``report_expression``; the value "quantized" selects
            the quantized variant.
        quant_labels: forwarded to ``in_label`` / ``out_label``.

    The name cache is reset even if drawing or rendering raises.
    """
    import os  # local: only needed to close the mkstemp file descriptor

    if nodes is None:
        nodes = set(G.nodes())
    self.init_name_cache()
    try:
        all_ports = {}
        graph_name = getattr(G, 'graphname', 'graph')
        dot = Digraph(comment=graph_name, format=graph_format,
                      node_attr={'height': '.1'}, edge_attr={'fontsize': '10.0'})
        fake_idx = 0
        for node in G.dfs():
            if node not in nodes:
                continue
            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot, G, node, anonymise=anonymise,
                    report_quantized=expressions == "quantized")
            else:
                num_in_edges = len(G.indexed_in_edges(node.name))
                num_out_edges = len(G.indexed_out_edges(node.name))
                ports = all_ports.setdefault(node, [None] * 2)
                names = self.build_nodebox(node, ports, num_in_edges, num_out_edges,
                                           anon=anonymise)
                dot.node(node.name, nohtml(names), shape='record',
                         xlabel=str(node.step_idx))
            for edge in G.in_edges(node.name):
                # Edges from undrawn producers are only shown when all_dims is set.
                if edge.from_node not in nodes and not all_dims:
                    continue
                out_port, in_port = self.get_ports(all_ports, edge)
                if edge.from_node in nodes:
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    dot.edge(from_node_id, to_node_id,
                             xlabel=self.in_label(G, node, edge.to_idx, quant_labels))
                else:
                    # Producer not drawn: attach the edge to a point placeholder.
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name, shape='point', fillcolor='black')
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    dot.edge(fake_name, to_node_id,
                             xlabel=self.in_label(G, node, edge.to_idx, quant_labels))
            if not all_dims:
                continue
            # Output groups with no drawn consumer end in a blank placeholder.
            for edge_group in G.indexed_out_edges(node.name):
                if any(edge.to_node in nodes for edge in edge_group):
                    continue
                edge = edge_group[0]
                out_port, _ = self.get_ports(all_ports, edge)
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name, shape='plaintext', label=' ', fillcolor='black')
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                dot.edge(from_node_id, fake_name,
                         xlabel=self.out_label(G, node, edge.from_idx, quant_labels))
        if filename:
            dot.render(filename, cleanup=True)
        if view:
            # mkstemp creates the file atomically, avoiding the name race that
            # makes tempfile.mktemp deprecated; graphviz then overwrites it with
            # the source and cleanup=True removes it after rendering.
            fd, view_path = tempfile.mkstemp(suffix='.gv')
            os.close(fd)
            dot.view(view_path, cleanup=True, quiet=True)
    finally:
        self.reset_name_cache()
def quantize_backward(self, G: NNGraph, result, edge_recs, node, force_out=None):
    """Requantize ``node`` so its output satisfies ``force_out``.

    When needed, constraints are pushed recursively to producer nodes.

    The node is computed at most twice. If the first attempt fails
    ``satisfied_force``, the node is recalculated once, since producers may
    have been requantized in the meantime. A second failure raises.

    Args:
        G: graph containing ``node``.
        result: record container; updated with this node's record, and with
            inner-node records for ``ConvFusionParameters``.
        edge_recs: dict keyed by edge params; this node's ``out_qs`` are
            written onto its outgoing edges.
        node: node to requantize.
        force_out: required output ``QType`` or None.

    Returns:
        The node's final record.

    Raises:
        NotImplementedError: no solution was found, the node is a softmax,
            a generic node has several inputs, or the q shift needed on a
            filter's input exceeds the input's q.
    """
    LOG.debug("quantize backwards %s", node.name)
    recalculated = False
    while True:
        in_qs = self.get_in_qs(G, edge_recs, node)
        if self.is_filter_node(node):
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs, force_out=force_out)
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            else:
                # NOTE(review): quantize_forward calls calculate_q with
                # (node, astats, fstats, in_qs, min_qsnr, force_width); this call
                # passes a different argument list — confirm against calculate_q.
                qrec = self.calculate_q(node, self._activation_stats.get(
                    NodeId(node, None)), in_qs, self._force_width, force_out=force_out)
            out_q = qrec.out_qs[0].q
            if force_out and force_out.q is not None and out_q < force_out.q:
                if recalculated:
                    raise NotImplementedError("no quantization solution found")
                # The record itself has no .q; the shortfall is measured on its output.
                bits_to_gain = force_out.q - out_q
                if bits_to_gain > in_qs[0].q:
                    raise NotImplementedError()
                # Try to adjust the inputs to satisfy and then recalculate
                pnode = G.in_edges(node.name)[0].from_node
                self.quantize_backward(G, result, edge_recs, pnode,
                                       force_out=QType(bits=force_out.bits,
                                                       q=in_qs[0].q - bits_to_gain,
                                                       signed=True))
        elif isinstance(node, ConcatParameters):
            assert not recalculated
            # All concat inputs are normalized to the widest width and smallest q,
            # unless force_out demands otherwise.
            max_width = max(in_q.bits for in_q in in_qs)
            min_q = min(in_q.q for in_q in in_qs)
            if force_out:
                if not self.satisfied(force_out.bits, max_width):
                    max_width = force_out.bits
                if not self.satisfied(force_out.q, min_q):
                    min_q = force_out.q
            LOG.debug("normalizing concat to %s",
                      QType(bits=max_width, q=min_q, signed=True))
            producers = [edge.from_node for edge in G.in_edges(node.name)]
            for pidx, pnode in enumerate(producers):
                pqrec = in_qs[pidx]
                if pqrec.q != min_q or pqrec.bits != max_width:
                    self.quantize_backward(G, result, edge_recs, pnode,
                                           force_out=QType(bits=max_width, q=min_q,
                                                           signed=True))
            o_q = QType(bits=max_width, q=min_q, signed=True)
            # Re-read inputs: the recursive calls above may have changed them.
            qrec = SymmetricQuantizationRecord(in_qs=self.get_in_qs(G, edge_recs, node),
                                               out_qs=[o_q])
        elif isinstance(node, SoftMaxParameters):
            raise NotImplementedError("softmax kernel cannot change width or q")
        else:
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs, force_out=force_out)
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            else:
                qrec = self.calculate_q(node, self._activation_stats.get(
                    NodeId(node, None)), in_qs, self._force_width, force_out=force_out)
            o_q = qrec.out_qs[0]
            # Without a forced output there is nothing to push upstream.
            if force_out is not None and not (
                    self.satisfied(force_out.q, o_q.q)
                    and self.satisfied(force_out.bits, o_q.bits)):
                if recalculated:
                    raise NotImplementedError("no quantization solution found")
                if len(G.in_edges(node.name)) > 1:
                    raise NotImplementedError(
                        "Nodes with multiple input edges need custom handling")
                pnode = G.in_edges(node.name)[0].from_node
                self.quantize_backward(G, result, edge_recs, pnode, force_out=force_out)
        for edges in G.indexed_out_edges(node.name):
            for edge in edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        result[NodeId(node, None)] = qrec
        o_q = qrec.out_qs[0]
        if self.satisfied_force(force_out, o_q):
            break
        if recalculated:
            raise NotImplementedError("no quantization solution found")
        LOG.debug("recalculate %s", node.name)
        recalculated = True
    LOG.debug("back complete %s %s", node.name, qrec)
    return qrec