def calculate_q(self, node, astats, in_qs, force_width, force_out=None):
    """Build a quantization record for *node* dispatched on its parameter type.

    Input/constant/broadcast nodes get an output qtype from stats; conv and
    fully-connected nodes go through the filter quantizer; softmax is pinned
    to Q15; activations compute their own output qtype; anything else passes
    its input qtypes straight through.
    """
    if isinstance(node, (InputParameters, MatrixBroadcastedLinearOpParameters,
                         ConstantInputParameters, MatScaleFusionParameters)):
        return self.calculate_output_q(node, astats, in_qs,
                                       force_width=force_width,
                                       force_out=force_out)
    # conv and linear share the same filter quantization path
    if isinstance(node, (Conv2DParameters, FcParameters)):
        return self.calculate_filter_q(node, astats, in_q=in_qs[0],
                                       force_width=force_width,
                                       force_out=force_out)
    if isinstance(node, SoftMaxParameters):
        # softmax always outputs Q15
        return SymmetricQuantizationRecord(in_qs=in_qs,
                                           out_qs=[QType(16, 15, True)])
    if isinstance(node, ActivationParameters):
        return SymmetricQuantizationRecord(
            in_qs=in_qs,
            out_qs=[self.compute_activation_out_qtype(node, in_qs[0])])
    # default: node does not change quantization
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=in_qs)
def _quantize(cls, params, in_qs, stats, **kwargs):
    """Quantize a split node: all outputs must share the input qtype.

    When outputs are forced, all forced outputs must agree.  Going
    backwards with a forced output also forces the input; going forwards a
    forced output that differs from the input cannot be satisfied
    (returns None).
    """
    o_q = in_qs[0]
    force_out_qs, _ = cls.get_pow2_opts(**kwargs)
    first_forced_q = force_out_qs and next(
        iter(out_q for out_q in force_out_qs if out_q is not None), None)
    if first_forced_q and not all(out_q == first_forced_q
                                  for out_q in force_out_qs
                                  if out_q is not None):
        LOG.error(
            'split %s is being forced to have different output qtypes',
            params.name)
        return None
    if first_forced_q:
        backwards = kwargs.get('backwards', None)
        if backwards:
            # if going backwards and forced then we force our input
            # BUG FIX: in_qs must be a list of qtypes (every other record in
            # this quantizer receives a list); previously the bare QType was
            # passed as in_qs.
            return SymmetricQuantizationRecord(
                in_qs=[deepcopy(first_forced_q)],
                out_qs=[
                    deepcopy(first_forced_q)
                    for _ in range(params.num_splits)
                ])
        elif o_q != first_forced_q:
            LOG.error(
                'split %s is being forced to have different output to input',
                params.name)
            return None
        # continue here if forced since o_q == forced_q
    return SymmetricQuantizationRecord(
        in_qs=in_qs,
        out_qs=[deepcopy(o_q) for _ in range(params.num_splits)])
def quantize_fusion(self, G: NNGraph, node: ConvFusionParameters, in_qs, force_out=None) -> SymmetricQuantizationRecord:
    """Quantize the nodes contained in a fusion node.

    For 'conv_active' fusions: the conv is quantized with its output left
    as the accumulator qtype (out_as_acc=True) and the activation output
    qtype is derived from the recorded activation stats, optionally
    constrained by *force_out*.  Returns (fusion_record, per_node_records).
    All other fusion types are delegated to default_quantize_fusion.
    """
    if node.fusion_type != 'conv_active':
        return self.default_quantize_fusion(G, node, in_qs, force_out=force_out)
    result = OrderedDict()
    nodes = node.contained_nodes()
    conv_node = nodes[0]
    conv_astats = self._activation_stats.get(NodeId(node, conv_node))
    conv_qrec = self.calculate_filter_q(conv_node,
                                        conv_astats,
                                        in_q=in_qs[0],
                                        force_width=self._force_width,
                                        out_as_acc=True)
    result[NodeId(node, conv_node)] = conv_qrec
    act_node = nodes[1]
    act_astats = self._activation_stats.get(NodeId(node, act_node))
    if force_out and force_out.bits:
        act_max_q = self.compute_activation_out_maxq(act_node, force_out.bits)
        if force_out.q is not None:
            if (act_max_q is not None and force_out.q > act_max_q
                    ) or force_out.q > conv_qrec.out_qs[0].q:
                # We cannot shift left in the kernel
                # TODO - This should try to increase the input q and perhaps the width
                # Unlikely to happen
                raise NotImplementedError()
            act_o_q = QType(bits=force_out.bits, q=force_out.q, signed=True)
        else:
            # BUG FIX: act_o_q was previously referenced here before
            # assignment (UnboundLocalError).  Derive it from the activation
            # stats at the forced width, then clamp as in the unforced path.
            act_o_q = QType.from_min_max(
                max_val=act_astats['range_out'][0]['max'],
                min_val=act_astats['range_out'][0]['min'],
                bits=force_out.bits)
            if act_max_q is not None:
                act_o_q.q = min(act_max_q, act_o_q.q)
            act_o_q.q = min(act_o_q.q, conv_qrec.out_qs[0].q)
    else:
        act_o_q = QType.from_min_max(
            max_val=act_astats['range_out'][0]['max'],
            min_val=act_astats['range_out'][0]['min'],
            bits=self._force_width)
        # activation cannot output more fractional bits than the conv produced
        act_o_q.q = min(act_o_q.q, conv_qrec.out_qs[0].q)
        if force_out and force_out.q:
            # BUG FIX: act_max_q was unbound on this path (it was only set
            # when force_out.bits was given); compute it at the chosen width.
            act_max_q = self.compute_activation_out_maxq(act_node, act_o_q.bits)
            if (act_max_q is not None and force_out.q > act_max_q
                    ) or force_out.q > conv_qrec.out_qs[0].q:
                # We cannot shift left in the kernel
                # TODO - This should try to increase the input q and perhaps the width
                # Unlikely to happen
                raise NotImplementedError()
            act_o_q.q = force_out.q
    act_qrec = SymmetricQuantizationRecord(in_qs=conv_qrec.out_qs, out_qs=[act_o_q])
    result[NodeId(node, act_node)] = act_qrec
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=act_qrec.out_qs), result
def _quantize(cls, params, in_qs, stats, **kwargs):
    """Quantize an activation node in POW2 mode.

    The integer-bit requirement of the output is derived from the
    activation type (fixed for relu6/relun, from recorded range stats for
    relu/hswish/hsigmoid/leaky).  The output Q is the remaining fractional
    bits, unless an output is forced, in which case a clipping warning is
    emitted when the forced format cannot hold the integer part.
    """
    force_out_qs, out_dtype = cls.get_pow2_opts(**kwargs)
    force_out_q = force_out_qs and force_out_qs[0]
    activation = params.activation
    if activation == "relu6":
        int_bits = calc_bits(6)
    elif activation == "relun":
        relun = params.activation_params
        if isinstance(relun, list):
            # per-channel relu-n: size for the largest bound
            relun = max(relun)
        int_bits = calc_bits(relun)
    elif activation in ("relu", "hswish", "hsigmoid", "leaky"):
        int_bits = bits(stats['range_out'][0]['max'],
                        stats['range_out'][0]['min'])
    else:
        raise ValueError(
            f'no support for activation {activation} in POW2 quantizer'
        )
    in_q = in_qs[0]
    if force_out_q is None:
        fract_bits = max(cls.get_pow2_bits(**kwargs) - int_bits, 0)
        out_q = QType(q=fract_bits, dtype=out_dtype)
    else:
        if force_out_q.bits - force_out_q.q < int_bits:
            LOG.warning(
                'quantization is forcing node %s to have an output that may clip',
                params.name)
        out_q = force_out_q
    return SymmetricQuantizationRecord(in_qs=[in_q], out_qs=[out_q])
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Replace a matched conv subgraph with one ConvFusionParameters node.

    Raises DontReplaceError when validate_match rejects the subgraph.  When
    the graph carries quantization, a fused record built from the first
    contained record's in_qs and the last one's out_qs is attached to the
    new node and the per-node records are moved inside the fusion.
    Returns (fused_node, None, None).
    """
    if not self.validate_match(subgraph):
        raise DontReplaceError()
    # simple node order is necessary because nodes() will not necessarily
    # be in order
    step = 0
    for node in subgraph.nodes():
        node.step_idx = step
        step = step + 1
        if isinstance(node, Conv2DParameters):
            conv_name = node.name + "_fusion"
            break
    LOG.debug("fused nodes %s",
              ",".join((node.name for node in subgraph.nodes())))
    pnode = ConvFusionParameters(conv_name,
                                 fusion_type=self.fusion_type,
                                 subgraph=subgraph)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            # FIX: initialize prec so an unrecognized record type cannot
            # raise UnboundLocalError below (consistent with the other
            # fusion matchers in this file, which start with prec = None)
            prec = None
            if isinstance(qrecs[0], (SymmetricQuantizationRecord,
                                     SymmetricScalableFilterQuantizationRecord)):
                prec = SymmetricQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                   out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (MultQuantizationRecord,
                                       MultScalableFilterQuantizationRecord)):
                prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                              out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (Float32QuantizationRecord,
                                       Float32ScalableFilterQuantizationRecord)):
                prec = Float32QuantizationRecord(in_qs=qrecs[0].in_qs,
                                                 out_qs=qrecs[-1].out_qs)
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode, None, None
def replace_function(self, G: NNGraph, subgraph: GraphView):
    """Replace a matched linear+activation subgraph with one fusion node.

    When the graph carries quantization, a fused record built from the
    first contained record's in_qs and the last one's out_qs is attached to
    the new node and the per-node records are moved inside the fusion.
    Returns (fused_node, None, None).
    """
    # simple node order is necessary because nodes() will not necessarily
    # be in order
    step = 0
    for node in subgraph.nodes():
        node.step_idx = step
        step = step + 1
        if isinstance(node, FcParameters):
            linear_name = node.name + "_fusion"
            break
    LOG.info("fusing nodes %s", ",".join(
        (node.name for node in subgraph.nodes())))
    pnode = ConvFusionParameters(linear_name,
                                 fusion_type="linear_active",
                                 subgraph=subgraph)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            # FIX: initialize prec so an unrecognized record type cannot
            # raise UnboundLocalError below (consistent with the other
            # fusion matchers in this file, which start with prec = None)
            prec = None
            if isinstance(qrecs[0], (SymmetricQuantizationRecord,
                                     SymmetricScalableFilterQuantizationRecord)):
                prec = SymmetricQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                   out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (MultQuantizationRecord,
                                       MultScalableFilterQuantizationRecord)):
                prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                              out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (Float32QuantizationRecord,
                                       Float32ScalableFilterQuantizationRecord)):
                prec = Float32QuantizationRecord(in_qs=qrecs[0].in_qs,
                                                 out_qs=qrecs[-1].out_qs)
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode, None, None
def replace_function(self, G: NNGraph, subgraph: GraphView):
    """Replace a matched op+activation pair with one ActivationFusion node.

    Assumes the subgraph holds exactly two nodes (op then activation).
    When the graph carries quantization, a fused record built from the
    first record's in_qs and the last record's out_qs is attached to the
    new node and the per-node records are moved inside the fusion.
    """
    nodes = list(subgraph.nodes())
    pnode = ActivationFusion(nodes[0].name + "fusion",
                             nodes[0].op_name + "_active", subgraph)
    nodes[0].step_idx = 0
    nodes[1].step_idx = 1
    LOG.debug("fused nodes %s", ",".join((node.name for node in nodes)))
    if G.quantization:
        qrecs = G.quantization.get_all(subgraph.nodes())
        if qrecs:
            # FIX: initialize prec so an unrecognized record type cannot
            # raise UnboundLocalError below (consistent with the other
            # fusion matchers in this file, which start with prec = None)
            prec = None
            if isinstance(qrecs[0], (SymmetricQuantizationRecord,
                                     SymmetricScalableFilterQuantizationRecord)):
                prec = SymmetricQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                   out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (MultQuantizationRecord,
                                       MultScalableFilterQuantizationRecord)):
                prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                              out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (Float32QuantizationRecord,
                                       Float32ScalableFilterQuantizationRecord)):
                prec = Float32QuantizationRecord(in_qs=qrecs[0].in_qs,
                                                 out_qs=qrecs[-1].out_qs)
            for node in subgraph.nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode
def calculate_output_q(self, node: Parameters, astats, in_qs,
                       force_width=None, force_out=None):
    """Derive an output qtype from the recorded output range statistics.

    force_out, when given, may pin bits, q, or both; otherwise the qtype
    is chosen from the min/max stats at force_width bits.
    """
    del node  # unused; kept for a uniform signature
    out_range = astats['range_out'][0]
    if not force_out:
        # unforced: size purely from the observed range
        o_q = QType.from_min_max(max_val=out_range['max'],
                                 min_val=out_range['min'],
                                 bits=force_width)
    elif force_out.bits:
        if force_out.q:
            # both bits and q pinned: build the qtype directly
            o_q = QType(bits=force_out.bits, q=force_out.q, signed=True)
        else:
            # only bits pinned: choose q from the range at that width
            o_q = QType.from_min_max(max_val=out_range['max'],
                                     min_val=out_range['min'],
                                     bits=force_out.bits)
    elif force_out.q:
        # only q pinned: size from the range, then override q
        o_q = QType.from_min_max(max_val=out_range['max'],
                                 min_val=out_range['min'],
                                 bits=force_width)
        o_q.q = force_out.q
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[o_q])
def _quantize(cls, params, in_qs, stats, **kwargs):
    """Quantize a node whose output is always 16-bit Q15.

    Refuses (returns None) any forced output that is not exactly Q15.
    """
    force_out_qs, _ = cls.get_pow2_opts(**kwargs)
    forced = force_out_qs and force_out_qs[0]
    q15 = QType.Pow2(16, 15, True)
    if forced and forced != q15:
        # only the fixed Q15 output can be produced
        return None
    return SymmetricQuantizationRecord(in_qs=in_qs,
                                       out_qs=[QType.Pow2(16, 15, True)])
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Collapse a relu/constant/mul subgraph into one HSigmoid activation.

    The matched relu, constant-input and mul nodes are located in the
    subgraph; the constant's quantization record is dropped and a new
    record combining the relu's in_qs with the mul's out_qs is attached to
    the replacement activation.  Returns (activation, None, None).
    """
    relu_node = constant_node = mul_node = None
    for candidate in subgraph.nodes():
        if isinstance(candidate, ReluActivationParameters):
            relu_node = candidate
        elif isinstance(candidate, ConstantInputParameters):
            constant_node = candidate
        elif isinstance(candidate, MatrixMulParameters):
            mul_node = candidate
    activation = HSigmoidActivationParameters(
        mul_node.name + "_fused_close_hsigmoid", offset=0)
    if G.quantization:
        reluqrec = G.quantization[NodeId(relu_node)]
        mulqrec = G.quantization[NodeId(mul_node)]
        del G.quantization[NodeId(constant_node)]
        # pick the record class matching the relu's record kind
        if isinstance(reluqrec, SymmetricQuantizationRecord):
            record_cls = SymmetricQuantizationRecord
        elif isinstance(reluqrec, MultQuantizationRecord):
            record_cls = MultQuantizationRecord
        elif isinstance(reluqrec, Float32QuantizationRecord):
            record_cls = Float32QuantizationRecord
        else:
            raise NotImplementedError()
        G.quantization[NodeId(activation)] = record_cls(
            in_qs=reluqrec.in_qs, out_qs=mulqrec.out_qs)
    return activation, None, None
def match(self, G: GraphView, set_identity: bool = True):
    """Find conv/activation/pool chains in *G* and collapse each into a
    single ConvFusionParameters node.

    For every Conv2DParameters node, get_node_list builds the candidate
    chain; chains shorter than 2 nodes are skipped.  Average pooling is
    restricted (see inline comments).  When the graph carries
    quantization, a fused record is synthesized from the first contained
    record's in_qs and the last one's out_qs, and the per-node records are
    moved inside the fusion.  Returns True when at least one fusion was
    performed.
    """
    has_modified_graph = False
    for conv_node in [params for params in G.nodes() if isinstance(params, Conv2DParameters)]:
        node_list = self.get_node_list(G, conv_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        if node_list.fusion_type == 'conv_active_pool':
            if node_list.pool.pool_type == "average":
                # drop an average pool that follows the activation: keep
                # only conv + activation in the fusion
                node_list.order = node_list.order[:2:]
                node_list.pool = None
        elif node_list.fusion_type == 'conv_pool_active':
            # NOTE(review): average pool before a non-relu activation is
            # skipped entirely — presumably unsupported by the kernels;
            # confirm against the kernel implementations
            if node_list.pool.pool_type == "average" and node_list.active.activation != "relu":
                continue
        LOG.info("fusing nodes %s", ",".join((node.name for node in node_list.order)))
        has_modified_graph = True
        # build the internal subgraph as a simple chain in fusion order
        subgraph = GraphView()
        last_node = None
        for node in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
            last_node = node
        # all three external inputs (input, weights, biases) map onto the conv
        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(
            node_list.conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=node_list.conv.in_dims_hint,
            out_dims_hint=node_list.conv.out_dims_hint,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture the external edges BEFORE removing the fused nodes
        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        # rewire the captured edges onto the fusion node
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _quantize(cls, params, in_qs, stats, **kwargs):
    """Quantize a pass-through node: the output qtype is the input qtype.

    Going backwards, a forced output is propagated to every input; going
    forwards, a forced output that does not match all inputs cannot be
    satisfied and None is returned.
    """
    force_out_qs, _ = cls.get_pow2_opts(**kwargs)
    forced_q = force_out_qs and force_out_qs[0]
    if kwargs.get('backwards'):
        # if output must be forced
        assert forced_q, f'going backwards at {params.name} but output is not forced'
        return SymmetricQuantizationRecord(in_qs=[forced_q] * len(in_qs),
                                           out_qs=[deepcopy(forced_q)])
    # forwards: a forced output must already match every input
    if forced_q and any(in_q != forced_q for in_q in in_qs):
        return None
    return SymmetricQuantizationRecord(in_qs=in_qs,
                                       out_qs=[deepcopy(in_qs[0])])
def _quantize(cls, params, in_qs, stats, **kwargs):
    """Choose a pow2 output qtype from the recorded output range.

    A forced output overrides the computed qtype; if the forced format has
    fewer integer bits than the range needs, a clipping warning is logged.
    """
    force_out_qs, out_dtype = cls.get_pow2_opts(**kwargs)
    forced = force_out_qs and force_out_qs[0]
    out_range = stats['range_out'][0]
    o_q = QType.from_min_max_pow2(out_range['min'],
                                  out_range['max'],
                                  dtype=out_dtype)
    if forced:
        # fewer integer bits than required -> possible clipping
        if forced.bits - forced.q < o_q.bits - o_q.q:
            LOG.warning('%s is being forced to output in Q%s and may clip',
                        params.name, forced.q)
        o_q = forced
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[o_q])
def default_quantize_fusion(self, G: NNGraph, node: ConvFusionParameters, in_qs, force_out=None) -> SymmetricQuantizationRecord:
    """Quantize a fusion by chaining its contained nodes in order.

    Each contained node is quantized with the previous node's output
    qtypes as its inputs.  Returns (fusion_record, per_node_records) where
    the fusion record spans the original inputs and the final outputs.
    """
    del G  # unused; kept for a uniform signature
    records = OrderedDict()
    current_qs = in_qs
    for fnode in node.contained_nodes():
        fstats = self._activation_stats.get(NodeId(node, fnode))
        fqrec = self.calculate_q(fnode, fstats, current_qs,
                                 self._force_width, force_out=force_out)
        records[NodeId(node, fnode)] = fqrec
        # next node consumes this node's outputs
        current_qs = fqrec.out_qs
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=current_qs), records
def calculate_output_q(self, node: Parameters, astats, in_qs,
                       min_qsnr=None, force_width=None, force_out=None):
    """Derive an output qtype from stats via get_quantization.

    force_out, when given, may pin bits, q, or both; otherwise the qtype
    comes from get_quantization at min_qsnr / force_width.
    """
    del node  # unused; kept for a uniform signature
    if not force_out:
        # unforced: let get_quantization pick from the stats
        o_q = self.get_quantization(astats, min_qsnr, force_width)
    elif force_out.bits:
        if force_out.q:
            # both bits and q pinned: build the qtype directly
            o_q = QType(bits=force_out.bits, q=force_out.q, signed=True)
        else:
            # only bits pinned: quantize from stats at that width
            o_q = self.get_quantization(astats, None, force_out.bits)
    elif force_out.q:
        # only q pinned: quantize from stats, then override q
        o_q = self.get_quantization(astats, min_qsnr, force_width)
        o_q.q = force_out.q
    return SymmetricQuantizationRecord(in_qs=in_qs, out_qs=[o_q])
def match(self, G: GraphView, set_identity: bool = True):
    """Find pad/add(/activation) chains and collapse each into a
    PaddedAddFusionParameters node.

    The pad may feed either input of the add; input_mapping and the edge
    rewiring both depend on which index the padded tensor arrives on.
    When the graph carries quantization, a fused record is synthesized
    from the first contained record's in_qs and the last one's out_qs.
    Returns True when at least one fusion was performed.
    """
    has_modified_graph = False
    for pad_node in [
            params for params in G.nodes()
            if isinstance(params, PadParameters)
    ]:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # which input of the add receives the padded tensor
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad,
                   to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        # NOTE(review): presumably pins which input's qtype the add's output
        # follows — confirm against PaddedAddFusionParameters/kernel code
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # external input order depends on which add input is padded
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                  out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture external edges in the order matching input_mapping,
        # BEFORE removing the fused nodes
        if padded_input_idx == 0:
            in_edges = G.in_edges(node_list.pad.name) + G.indexed_in_edges(
                node_list.add.name)[1::]
        else:
            in_edges = G.indexed_in_edges(
                node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        # rewire the captured edges onto the fusion node
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       pnode,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode,
                       edge.to_node,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def quantize_backward(self, G: NNGraph, result, edge_recs, node, force_out=None):
    """Re-quantize *node* (and, recursively, its predecessors) so that its
    output satisfies *force_out*.

    Loops at most twice: the first pass may recursively adjust the inputs
    (via quantize_backward on predecessor nodes) and then recalculates;
    if the force is still not satisfied after one recalculation a
    NotImplementedError is raised.  Updates *result* (NodeId -> qrec) and
    *edge_recs* (edge params -> output qtype) in place and returns the
    node's final quantization record.
    """
    LOG.debug("quantize backwards %s", node.name)
    recalculated = False
    while True:
        in_qs = self.get_in_qs(G, edge_recs, node)
        if self.is_filter_node(node):
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs, force_out=force_out)
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            else:
                qrec = self.calculate_q(node, self._activation_stats.get(
                    NodeId(node, None)), in_qs, self._force_width, force_out=force_out)
            if force_out and force_out.q is not None and qrec.out_qs[0].q < force_out.q:
                if recalculated:
                    raise NotImplementedError(
                        "no quantization solution found")
                # NOTE(review): qrec is a quantization record, not a qtype —
                # qrec.q looks suspect; probably qrec.out_qs[0].q was meant.
                # Confirm before relying on this path.
                bits_to_gain = force_out.q - qrec.q
                if bits_to_gain > in_qs[0].q:
                    raise NotImplementedError()
                # Try to adjust the inputs to satisfy and then
                # recalculate
                pnode = G.in_edges(node.name)[0].from_node
                self.quantize_backward(G, result, edge_recs, pnode,
                                       force_out=QType(bits=force_out.bits,
                                                       q=in_qs[0].q - bits_to_gain,
                                                       signed=True))
        elif isinstance(node, ConcatParameters):
            assert not recalculated
            # normalize all concat inputs to a common width and q
            max_width = max(in_q.bits for in_q in in_qs)
            min_q = min(in_q.q for in_q in in_qs)
            if force_out:
                if not self.satisfied(force_out.bits, max_width):
                    max_width = force_out.bits
                if not self.satisfied(force_out.q, min_q):
                    min_q = force_out.q
            LOG.debug("normalizing concat to %s",
                      QType(bits=max_width, q=min_q, signed=True))
            # force every non-conforming predecessor to the common qtype
            for pidx, pnode in enumerate(
                    [edge.from_node for edge in G.in_edges(node.name)]):
                pqrec = in_qs[pidx]
                if pqrec.q != min_q or pqrec.bits != max_width:
                    self.quantize_backward(G, result, edge_recs, pnode,
                                           force_out=QType(bits=max_width,
                                                           q=min_q,
                                                           signed=True))
            o_q = QType(bits=max_width, q=min_q, signed=True)
            # re-read in_qs: the recursive calls above updated edge_recs
            qrec = SymmetricQuantizationRecord(in_qs=self.get_in_qs(
                G, edge_recs, node), out_qs=[o_q])
        elif isinstance(node, SoftMaxParameters):
            raise NotImplementedError(
                "softmax kernel cannot change width or q")
        else:
            if isinstance(node, ConvFusionParameters):
                qrec, qrecs = self.quantize_fusion(G, node, in_qs, force_out=force_out)
                for node_id, fqrec in qrecs.items():
                    result[node_id] = fqrec
            else:
                qrec = self.calculate_q(node, self._activation_stats.get(
                    NodeId(node, None)), in_qs, self._force_width, force_out=force_out)
            o_q = qrec.out_qs[0]
            # NOTE(review): force_out may be None here (its default) —
            # force_out.q would raise AttributeError; presumably this method
            # is only ever called with a force. Confirm against callers.
            if not (self.satisfied(force_out.q, o_q.q)
                    and self.satisfied(force_out.bits, o_q.bits)):
                if recalculated:
                    raise NotImplementedError(
                        "no quantization solution found")
                if len(G.in_edges(node.name)) > 1:
                    raise NotImplementedError(
                        "Nodes with multiple input edges \
need custom handling")
                pnode = G.in_edges(node.name)[0].from_node
                self.quantize_backward(G, result, edge_recs, pnode, force_out=force_out)
        # publish this node's output qtypes on its outgoing edges
        for edges in G.indexed_out_edges(node.name):
            for edge in edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        result[NodeId(node, None)] = qrec
        o_q = qrec.out_qs[0]
        if self.satisfied_force(force_out, o_q):
            break
        if recalculated:
            raise NotImplementedError("no quantization solution found")
        LOG.debug("recalculate %s", node.name)
        recalculated = True
    LOG.debug("back complete %s %s", node.name, qrec)
    return qrec
def match(self, G: GraphView, set_identity: bool = True):
    """Find matmul(/add)(/activation) chains and collapse each into a
    MatMulOpFusionParameters node.

    The optional bias add is folded in as the matmul's third input, wired
    from the constant feeding the add.  When the graph carries
    quantization, a fused record is synthesized from the first contained
    record's in_qs and the last one's out_qs.  Returns True when at least
    one fusion was performed.
    """
    has_modified_graph = False
    for matmul_node in [
            params for params in G.nodes()
            if isinstance(params, MatMulOpParameters)
    ]:
        node_list = self.get_node_list(G, matmul_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        if node_list.active is not None:
            subgraph.add_edge(
                NNEdge(from_node=node_list.matmul,
                       to_node=node_list.active))
        # two matmul operands; a fused bias becomes a third input
        input_mapping = [[(node_list.matmul, idx)] for idx in range(2)]
        if node_list.add:
            input_mapping += [[(node_list.matmul, 2)]]
        output_mapping = [(node_list.active, 0)] if node_list.active else [(node_list.matmul, 0)]
        pnode = MatMulOpFusionParameters(node_list.matmul.name + '_fusion',
                                         fusion_type=node_list.fusion_type,
                                         subgraph=subgraph,
                                         input_mapping=input_mapping,
                                         output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                  out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture external edges BEFORE removing the fused nodes
        in_edges = G.in_edges(node_list.matmul.name)
        if node_list.add:
            # the constant input feeding the add supplies the bias
            bias_edge = [
                add_edge for add_edge in G.in_edges(node_list.add.name)
                if isinstance(add_edge.from_node, ConstantInputParameters)
            ][0]
        out_edges = G.out_edges(node_list.order[-1].name)
        for node in node_list.order:
            G.remove(node)
        # rewire the captured edges onto the fusion node
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       pnode,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
        if node_list.add:
            G.add_edge(
                NNEdge(bias_edge.from_node,
                       pnode,
                       from_idx=bias_edge.from_idx,
                       to_idx=2))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode,
                       edge.to_node,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph