def replace_function(self, G: NNGraph, subgraph: GraphView):
    """Replace a matched linear (FC) + activation subgraph with one fused node.

    Nodes of ``subgraph`` are numbered with a sequential ``step_idx`` up to and
    including the first ``FcParameters`` node found; the fused node is named
    after that node with a ``"_fusion"`` suffix and created with
    ``fusion_type="linear_active"``.

    When ``G`` carries quantization records for the contained nodes, those
    records are moved into the fusion and a single record for the fused node
    is built from the first record's ``in_qs`` and the last record's
    ``out_qs``. The record class follows the kind of the first record
    (symmetric, multiplicative or float32).

    Args:
        G: graph being rewritten; only ``G.quantization`` is used here.
        subgraph: the matched nodes to fuse.

    Returns:
        ``(pnode, None, None)`` where ``pnode`` is the new ``ConvFusionParameters``.

    Raises:
        ValueError: if ``subgraph`` contains no ``FcParameters`` node.
        NotImplementedError: if the first contained quantization record is of
            a kind this fusion does not know how to summarize. This is raised
            before any record is moved, so quantization state is left intact.
    """
    linear_name = None
    for step, node in enumerate(subgraph.nodes()):
        node.step_idx = step
        if isinstance(node, FcParameters):
            linear_name = node.name + "_fusion"
            break
    if linear_name is None:
        raise ValueError(
            "linear fusion subgraph contains no FcParameters node")
    LOG.info("fusing nodes %s", ",".join(
        node.name for node in subgraph.nodes()))
    # simple node order is necessary because nodes() will not necessarily
    # be in order
    pnode = ConvFusionParameters(
        linear_name, fusion_type="linear_active", subgraph=subgraph)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            first, last = qrecs[0], qrecs[-1]
            if isinstance(first, (SymmetricQuantizationRecord,
                                  SymmetricScalableFilterQuantizationRecord)):
                prec = SymmetricQuantizationRecord(
                    in_qs=first.in_qs, out_qs=last.out_qs)
            elif isinstance(first, (MultQuantizationRecord,
                                    MultScalableFilterQuantizationRecord)):
                prec = MultQuantizationRecord(
                    in_qs=first.in_qs, out_qs=last.out_qs)
            elif isinstance(first, (Float32QuantizationRecord,
                                    Float32ScalableFilterQuantizationRecord)):
                prec = Float32QuantizationRecord(
                    in_qs=first.in_qs, out_qs=last.out_qs)
            else:
                # Fail before moving any record so the graph's quantization
                # is not left half-migrated into the fusion.
                raise NotImplementedError(
                    "unsupported quantization record type {} for fusion {}".format(
                        type(first).__name__, linear_name))
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode, None, None
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Replace a matched convolution fusion subgraph with one fused node.

    The match is first checked with ``self.validate_match``. Nodes of
    ``subgraph`` are then numbered with a sequential ``step_idx`` up to and
    including the first ``Conv2DParameters`` node. The fused node is named
    after that conv with a ``"_fusion"`` suffix and uses ``self.fusion_type``.

    When ``G`` carries quantization records for the contained nodes, those
    records are moved into the fusion and a single record for the fused node
    is built from the first record's ``in_qs`` and the last record's
    ``out_qs``. The record class follows the kind of the first record.

    Args:
        G: graph being rewritten; only ``G.quantization`` is used here.
        subgraph: the matched nodes to fuse.

    Returns:
        ``(pnode, None, None)`` where ``pnode`` is the new ``ConvFusionParameters``.

    Raises:
        DontReplaceError: if ``validate_match`` rejects the subgraph or the
            subgraph contains no ``Conv2DParameters`` node.
        NotImplementedError: if the first contained quantization record is of
            a kind this fusion does not know how to summarize. This is raised
            before any record is moved, so quantization state is left intact.
    """
    if not self.validate_match(subgraph):
        raise DontReplaceError()
    conv_name = None
    for step, node in enumerate(subgraph.nodes()):
        node.step_idx = step
        if isinstance(node, Conv2DParameters):
            conv_name = node.name + "_fusion"
            break
    if conv_name is None:
        # Nothing to name the fusion after: decline the replacement instead of
        # failing later with an unbound local.
        raise DontReplaceError()
    LOG.debug("fused nodes %s", ",".join(
        node.name for node in subgraph.nodes()))
    # simple node order is necessary because nodes() will not necessarily
    # be in order
    pnode = ConvFusionParameters(
        conv_name, fusion_type=self.fusion_type, subgraph=subgraph)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            first, last = qrecs[0], qrecs[-1]
            if isinstance(first, (SymmetricQuantizationRecord,
                                  SymmetricScalableFilterQuantizationRecord)):
                prec = SymmetricQuantizationRecord(
                    in_qs=first.in_qs, out_qs=last.out_qs)
            elif isinstance(first, (MultQuantizationRecord,
                                    MultScalableFilterQuantizationRecord)):
                prec = MultQuantizationRecord(
                    in_qs=first.in_qs, out_qs=last.out_qs)
            elif isinstance(first, (Float32QuantizationRecord,
                                    Float32ScalableFilterQuantizationRecord)):
                prec = Float32QuantizationRecord(
                    in_qs=first.in_qs, out_qs=last.out_qs)
            else:
                # Fail before moving any record so the graph's quantization
                # is not left half-migrated into the fusion.
                raise NotImplementedError(
                    "unsupported quantization record type {} for fusion {}".format(
                        type(first).__name__, conv_name))
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode, None, None
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse each Conv2D node with the activation/pooling nodes that follow it.

    For every ``Conv2DParameters`` node in ``G``, ``self.get_node_list``
    supplies the chain of nodes to fuse (``None`` or a chain of fewer than two
    nodes means there is nothing to do). Special cases:

    * ``conv_active_pool`` with an average pool: only conv + activation are
      fused; the pool node is left in the graph.
    * ``conv_pool_active`` with an average pool and a non-``relu`` activation:
      the chain is not fused at all.

    The fused chain is replaced in ``G`` by one ``ConvFusionParameters`` node
    that takes over the conv's incoming edges and the last chained node's
    outgoing edges. When ``G.quantization`` is set, the contained records are
    moved into the fusion and a summary record is stored for the fused node.
    If the first record's kind is not recognized, ``None`` is stored.

    Args:
        G: graph rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one fusion was created.
    """
    modified = False
    # Take a snapshot: G is edited while the candidates are processed.
    conv_nodes = [candidate for candidate in G.nodes()
                  if isinstance(candidate, Conv2DParameters)]
    for conv_node in conv_nodes:
        fusion = self.get_node_list(G, conv_node)
        if fusion is None or len(fusion.order) < 2:
            continue
        kind = fusion.fusion_type
        if kind == 'conv_active_pool' and fusion.pool.pool_type == "average":
            # Keep only conv + activation in the fusion.
            fusion.order = fusion.order[:2]
            fusion.pool = None
        elif (kind == 'conv_pool_active'
              and fusion.pool.pool_type == "average"
              and fusion.active.activation != "relu"):
            continue

        LOG.info("fusing nodes %s", ",".join(n.name for n in fusion.order))
        modified = True

        chain = fusion.order
        subgraph = GraphView()
        for upstream, downstream in zip(chain, chain[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        tail = chain[-1]

        conv = fusion.conv
        pnode = ConvFusionParameters(
            conv.name + '_fusion',
            # read again: the chain may have been truncated above
            fusion_type=fusion.fusion_type,
            subgraph=subgraph,
            in_dims_hint=conv.in_dims_hint,
            out_dims_hint=conv.out_dims_hint,
            # fusion inputs 0..2 are routed to the conv's inputs 0..2
            # (presumably data, weights, biases - confirm)
            input_mapping=[[(conv, idx)] for idx in range(3)],
            output_mapping=[(tail, 0)])

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                record_kinds = (
                    ((SymmetricQuantizationRecord,
                      SymmetricScalableFilterQuantizationRecord),
                     SymmetricQuantizationRecord),
                    ((MultQuantizationRecord,
                      MultScalableFilterQuantizationRecord),
                     MultQuantizationRecord),
                    ((Float32QuantizationRecord,
                      Float32ScalableFilterQuantizationRecord),
                     Float32QuantizationRecord),
                )
                prec = None
                for accepted, summary_cls in record_kinds:
                    if isinstance(qrecs[0], accepted):
                        prec = summary_cls(in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                        break
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = prec

        # Capture the boundary edges before the chain nodes disappear.
        in_edges = G.in_edges(conv.name)
        out_edges = G.out_edges(tail.name)
        for fused in chain:
            G.remove(fused)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return modified
def quantize_fusion(self, G: NNGraph, node: ConvFusionParameters, in_qs, force_out=None) -> SymmetricQuantizationRecord:
    """Quantize the nodes contained in a fusion.

    Only ``'conv_active'`` fusions are handled here; any other fusion type is
    delegated to ``self.default_quantize_fusion``.

    For ``'conv_active'`` the first contained node (the convolution) is
    quantized with ``self.calculate_filter_q(..., out_as_acc=True)``. The
    second contained node (the activation) gets an output QType whose ``q``
    never exceeds the convolution output ``q``, because the kernel cannot
    shift left.

    Args:
        G: graph; only passed through to ``default_quantize_fusion``.
        node: the fusion being quantized.
        in_qs: input QTypes of the fusion; ``in_qs[0]`` feeds the convolution.
        force_out: optional QType constraint on the fusion output.
            ``force_out.bits`` (if set) fixes the output width.
            ``force_out.q`` (if set) fixes the output q.

    Returns:
        ``(fusion_qrec, inner_qrecs)``. ``inner_qrecs`` is an ``OrderedDict``
        mapping ``NodeId(node, inner)`` to each contained node's record. The
        return annotation is kept as-is for compatibility.

    Raises:
        NotImplementedError: if a forced output q is larger than the
            activation's maximum q (when known) or the convolution output q.
    """
    if node.fusion_type != 'conv_active':
        return self.default_quantize_fusion(G, node, in_qs, force_out=force_out)

    result = OrderedDict()
    nodes = node.contained_nodes()
    conv_node, act_node = nodes[0], nodes[1]

    conv_id = NodeId(node, conv_node)
    conv_qrec = self.calculate_filter_q(conv_node, self._activation_stats.get(conv_id),
                                        in_q=in_qs[0], force_width=self._force_width,
                                        out_as_acc=True)
    result[conv_id] = conv_qrec
    conv_out_q = conv_qrec.out_qs[0].q

    act_id = NodeId(node, act_node)
    act_astats = self._activation_stats.get(act_id)

    has_forced_bits = bool(force_out and force_out.bits)
    # Highest q the activation can produce at the forced width. It is None when
    # no width is forced, and compute_activation_out_maxq may also return None.
    act_max_q = (self.compute_activation_out_maxq(act_node, force_out.bits)
                 if has_forced_bits else None)

    def _check_forced_q(forced_q):
        # We cannot shift left in the kernel
        # TODO - This should try to increase the input q and perhaps the width
        # Unlikely to happen
        if (act_max_q is not None and forced_q > act_max_q) or forced_q > conv_out_q:
            raise NotImplementedError(
                "forced output q {} exceeds activation max q {} or convolution output q {}".format(
                    forced_q, act_max_q, conv_out_q))

    if has_forced_bits and force_out.q is not None:
        _check_forced_q(force_out.q)
        act_o_q = QType(bits=force_out.bits, q=force_out.q, signed=True)
    else:
        # Derive the output QType from the recorded activation output range,
        # at the forced width when one was given.
        act_range = act_astats['range_out'][0]
        act_o_q = QType.from_min_max(
            max_val=act_range['max'],
            min_val=act_range['min'],
            bits=force_out.bits if has_forced_bits else self._force_width)
        if act_max_q is not None:
            act_o_q.q = min(act_max_q, act_o_q.q)
    # The activation output q may not exceed the convolution output q.
    act_o_q.q = min(act_o_q.q, conv_out_q)
    if force_out and force_out.q:
        _check_forced_q(force_out.q)
        act_o_q.q = force_out.q

    act_qrec = SymmetricQuantizationRecord(in_qs=conv_qrec.out_qs, out_qs=[act_o_q])
    result[act_id] = act_qrec
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=act_qrec.out_qs), result
def default_quantize_fusion(self, G: NNGraph, node: ConvFusionParameters, in_qs, force_out=None) -> SymmetricQuantizationRecord:
    """Quantize the nodes of a fusion one after another.

    Each contained node is quantized with ``self.calculate_q``. The inputs are:

    * its activation statistics, looked up by ``NodeId(node, inner)``;
    * the previous contained node's output QTypes as input QTypes (``in_qs``
      for the first node);
    * ``self._force_width``;
    * ``force_out``, which is passed to every contained node, not only the last.

    Returns:
        ``(SymmetricQuantizationRecord(in_qs=in_qs, out_qs=<last out_qs>), records)``.
        ``records`` is an ``OrderedDict`` keyed by ``NodeId(node, inner)`` in
        contained-node order. The annotated return type does not reflect
        the tuple.
    """
    del G  # unused here
    records = OrderedDict()
    upstream_qs = in_qs
    for inner in node.contained_nodes():
        key = NodeId(node, inner)
        inner_qrec = self.calculate_q(inner,
                                      self._activation_stats.get(key),
                                      upstream_qs,
                                      self._force_width,
                                      force_out=force_out)
        records[key] = inner_qrec
        upstream_qs = inner_qrec.out_qs
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=upstream_qs), records
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each Conv2D node with the activation/pooling nodes that follow it.

    The set of activations accepted by ``self.get_node_list`` is
    ``VALID_ACTIVATIONS_POW2`` when ``kwargs['group_identity']`` is
    ``'pow2_match_group'`` and ``VALID_ACTIVATIONS_SQ8`` otherwise.

    For every ``Conv2DParameters`` node the matched chain (at least two
    nodes) is replaced by one ``ConvFusionParameters`` node. It copies the
    conv's input dims and the last chained node's output dims, takes over the
    conv's incoming edges, and takes over the last node's outgoing edges.
    Special cases:

    * ``conv_active_pool`` with an average pool: only conv + activation are
      fused; the pool node stays in the graph.
    * ``conv_pool_active`` with an average pool and a non-``relu`` activation:
      the chain is not fused.

    When ``G.quantization`` is set, the contained records are moved into the
    fusion. The fused node gets ``QRec.copy_ktype`` of the first record with
    deep-copied ``in_qs`` (first record) and ``out_qs`` (last record).

    Args:
        G: graph rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: only ``group_identity`` is read.

    Returns:
        True if at least one fusion was created.
    """
    modified = False
    valid_activations = (VALID_ACTIVATIONS_POW2
                         if kwargs.get('group_identity') == 'pow2_match_group'
                         else VALID_ACTIVATIONS_SQ8)
    # Take a snapshot: G is edited while the candidates are processed.
    conv_nodes = [candidate for candidate in G.nodes()
                  if isinstance(candidate, Conv2DParameters)]
    for conv_node in conv_nodes:
        fusion = self.get_node_list(G, conv_node, valid_activations)
        if fusion is None or len(fusion.order) < 2:
            continue
        kind = fusion.fusion_type
        if kind == 'conv_active_pool' and fusion.pool.pool_type == "average":
            # Keep only conv + activation in the fusion.
            fusion.order = fusion.order[:2]
            fusion.pool = None
        elif (kind == 'conv_pool_active'
              # NOTE: This is only for old POW2 kernels - SQ8 can handle this
              # NOTE(review): the check below runs regardless of
              # group_identity - confirm whether SQ8 should skip it.
              and fusion.pool.pool_type == "average"
              and fusion.active.activation != "relu"):
            continue

        LOG.info("fusing nodes %s", ",".join(n.name for n in fusion.order))
        modified = True

        chain = fusion.order
        subgraph = GraphView()
        for upstream, downstream in zip(chain, chain[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        tail = chain[-1]

        conv = fusion.conv
        pnode = ConvFusionParameters(
            conv.name + '_fusion',
            # read again: the chain may have been truncated above
            fusion_type=fusion.fusion_type,
            subgraph=subgraph,
            in_dims_hint=conv.in_dims_hint,
            out_dims_hint=conv.out_dims_hint,
            in_dims=deepcopy(conv.in_dims),
            out_dims=deepcopy(tail.out_dims),
            # fusion inputs 0..2 are routed to the conv's inputs 0..2
            input_mapping=[[(conv, idx)] for idx in range(3)],
            output_mapping=[(tail, 0)])

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                head_rec, tail_rec = qrecs[0], qrecs[-1]
                # TODO - stats
                prec = QRec.copy_ktype(head_rec,
                                       in_qs=deepcopy(head_rec.in_qs),
                                       out_qs=deepcopy(tail_rec.out_qs))
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = prec

        # Capture the boundary edges before the chain nodes disappear.
        in_edges = G.in_edges(conv.name)
        out_edges = G.out_edges(tail.name)
        for fused in chain:
            G.remove(fused)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return modified