def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace matched node chains with a single fusion node.

    For every node whose class is a key of ``VALID_FUSIONS``, ``get_node_list``
    (defined elsewhere) is asked for a match. When the match has two or more
    nodes in ``order`` they are linked, in that order, into a new ``GraphView``.
    That subgraph is wrapped in ``node_list.fusions_class``, and the fusion node
    replaces the matched nodes in ``G``.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called after all
            fusions have been made.
        **kwargs: not read by this method.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    for node in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
        node_list = self.get_node_list(G, node, FusionMatch(self._default_ktype))
        # a fusion needs at least two nodes
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        # chain the matched nodes one after another inside the fusion subgraph
        subgraph = GraphView()
        last_node = None
        for snode in node_list.order:
            if last_node is not None:
                subgraph.add_edge(
                    NNEdge(from_node=last_node, to_node=snode))
            last_node = snode
        # assumption here is that the first node could have multiple inputs but definitely has only
        # one output
        input_mapping = [[(node_list.node, idx)] for idx in range(G.num_in_edges(node_list.node.name))]
        output_mapping = [(last_node, 0)]
        pnode = node_list.fusions_class(node_list.node.name + '_fusion',
                                        fusion_type=node_list.fusion_type,
                                        subgraph=subgraph,
                                        input_mapping=input_mapping,
                                        output_mapping=output_mapping)
        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # the fusion record takes its inputs from the first contained
                # record and its outputs from the last one
                prec = QRec.copy_ktype(qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for fnode in pnode.contained_nodes():
                    G.quantization.move_to_fusion(fnode, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture the boundary edges BEFORE removing the matched nodes;
        # removal drops them from G
        in_edges = G.in_edges(node_list.node.name)
        out_edges = G.out_edges(last_node.name)
        for snode in node_list.order:
            G.remove(snode)
        # re-attach the external producers/consumers to the fusion node,
        # keeping the original port indices
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse each Conv2D node with the activation/pooling nodes matched after it.

    For every ``Conv2DParameters`` node, ``get_node_list`` (defined elsewhere)
    returns the matched chain. Chains of two or more nodes are wrapped in a
    ``ConvFusionParameters`` node, which replaces them in ``G``.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    for conv_node in [params for params in G.nodes() if isinstance(params, Conv2DParameters)]:
        node_list = self.get_node_list(G, conv_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        if node_list.fusion_type == 'conv_active_pool':
            if node_list.pool.pool_type == "average":
                # keep only conv + activation; the average pool is left in G
                # as a separate node.
                # NOTE(review): fusion_type is not updated here - confirm it is
                # derived from `order` rather than stored.
                node_list.order = node_list.order[:2:]
                node_list.pool = None
        elif node_list.fusion_type == 'conv_pool_active':
            # average pool followed by a non-relu activation is not fused
            if node_list.pool.pool_type == "average" and node_list.active.activation != "relu":
                continue
        LOG.info("fusing nodes %s", ",".join((node.name for node in node_list.order)))
        has_modified_graph = True
        # chain the matched nodes one after another inside the fusion subgraph
        subgraph = GraphView()
        last_node = None
        for node in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
            last_node = node
        # three fusion inputs all map onto the conv
        # (presumably input, weights, biases - TODO confirm)
        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(
            node_list.conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=node_list.conv.in_dims_hint,
            out_dims_hint=node_list.conv.out_dims_hint,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # NOTE(review): if qrecs[0] is none of the record types below,
                # prec stays None and None is stored as the fusion's record.
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture boundary edges before the matched nodes are removed
        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match_function(self, G: GraphView):
    """Match a ``PadParameters`` node directly feeding a ``FilterLikeParameters`` node.

    The filter-like node must also satisfy ``self.has_no_padding``.
    """
    def is_pad(candidate):
        return isinstance(candidate, PadParameters)

    def is_unpadded_filter(candidate):
        if not isinstance(candidate, FilterLikeParameters):
            return False
        return self.has_no_padding(candidate)

    pattern = GraphView()
    pattern.add_node(MatchNode('0', matcher=is_pad))
    pattern.add_node(MatchNode('1', matcher=is_unpadded_filter))
    pattern.add_edge(Edge('0', '1'))
    return G.match_fragment(pattern)
def match_function(self, G: GraphView):
    """Match an ``FcParameters`` node directly feeding an ``ActivationParameters`` node.

    The first node must pass ``self.valid_linear`` and the second
    ``self.valid_activation``.
    """
    fragment = GraphView()
    fragment.add_node(MatchNode(
        '0',
        matcher=lambda candidate: (isinstance(candidate, FcParameters)
                                   and self.valid_linear(candidate))))
    fragment.add_node(MatchNode(
        '1',
        matcher=lambda candidate: (isinstance(candidate, ActivationParameters)
                                   and self.valid_activation(candidate))))
    fragment.add_edge(Edge('0', '1'))
    matches = G.match_fragment(fragment)
    return matches
def split_down_from(cur_g, node, res_g=None):
    """Split ``cur_g`` into two graphs: everything from ``node`` down, and the rest.

    Every node reachable from ``node`` by following out edges is removed from
    ``cur_g`` and added, together with clones of its out edges, to ``res_g``.

    The reachable set is collected completely before anything is removed.
    Removing nodes while still walking breaks when two paths reach the same
    descendant. The shared node is removed on the first visit, which also
    drops the edge from its second parent. Revisiting it then asks ``cur_g``
    for the edges of a node it no longer contains.

    Args:
        cur_g: graph to split; modified in place.
        node: root of the part that is split off.
        res_g: graph receiving the split-off part; a new ``GraphView`` is
            created when None.

    Returns:
        ``res_g`` containing the split-off nodes and edges.
    """
    if res_g is None:
        res_g = GraphView()

    # depth-first, pre-order collection of every node reachable from `node`
    down_nodes = []
    seen_names = set()
    stack = [node]
    while stack:
        cur = stack.pop()
        if cur.name in seen_names:
            continue
        seen_names.add(cur.name)
        down_nodes.append(cur)
        # push children reversed so the first out edge is explored first
        stack.extend(edge.to_node for edge in list(cur_g.out_edges(cur.name))[::-1])

    # copy nodes and edges while they are all still present in cur_g
    for cur in down_nodes:
        if cur not in res_g.nodes():
            res_g.add_node(cur)
        for edge in cur_g.out_edges(cur.name):
            res_g.add_edge(edge.clone())

    # only now detach the split-off part from cur_g
    for cur in down_nodes:
        cur_g.remove(cur)
    return res_g
def match_function(self, G: GraphView):
    """Match a filter node and a constant node feeding a matrix add.

    The ``FilterParameters`` node feeds input 0 of a ``MatrixAddParameters``
    node, and a ``ConstantInputParameters`` node feeds its input 1.
    """
    node_tests = (
        ('0', lambda candidate: isinstance(candidate, FilterParameters)),
        ('1', lambda candidate: isinstance(candidate, MatrixAddParameters)),
        ('2', lambda candidate: isinstance(candidate, ConstantInputParameters)),
    )
    pattern = GraphView()
    for label, test in node_tests:
        pattern.add_node(MatchNode(label, matcher=test))
    pattern.add_edge(Edge('0', '1', to_idx=0))
    pattern.add_edge(Edge('2', '1', to_idx=1))
    return G.match_fragment(pattern)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each fully connected node with the nodes matched after it.

    The set of acceptable activations depends on ``kwargs['group_identity']``:
    ``VALID_ACTIVATIONS_POW2`` for ``'pow2_match_group'``, otherwise
    ``VALID_ACTIVATIONS_SQ8``. Chains of two or more nodes returned by
    ``get_node_list`` are wrapped in a ``LinearFusionParameters`` node, which
    replaces them in ``G``.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.
        **kwargs: only ``group_identity`` is read.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    group_identity = kwargs.get('group_identity')
    if group_identity == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8
    for fc_node in [params for params in G.nodes() if isinstance(params, FcParameters)]:
        node_list = self.get_node_list(G, fc_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        # chain the matched nodes one after another inside the fusion subgraph
        subgraph = GraphView()
        last_node = None
        for node in node_list.order:
            if last_node is not None:
                subgraph.add_edge(
                    NNEdge(from_node=last_node, to_node=node))
            last_node = node
        # three fusion inputs all map onto the linear node
        # (presumably input, weights, biases - TODO confirm)
        input_mapping = [[(node_list.linear, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = LinearFusionParameters(
            node_list.linear.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = QRec.copy_ktype(
                    qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture boundary edges before the matched nodes are removed
        in_edges = G.in_edges(node_list.linear.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def construct_subgraph(G, nodes):
    """Construct a subgraph of ``G`` containing ``nodes`` and the edges between them.

    Args:
        G: source graph; not modified.
        nodes: iterable of nodes of ``G`` to include. It is copied, so a list
            passed by the caller is no longer emptied as a side effect.

    Returns:
        A new ``GraphView`` holding every node in ``nodes`` and a clone of
        every edge of ``G`` whose two endpoints are both in ``nodes``.
    """
    # work on a private copy: popping from the caller's list would consume it
    remaining = list(nodes)
    sub_g = GraphView()
    while remaining:
        node = remaining.pop(0)
        if node not in sub_g.nodes():
            sub_g.add_node(node)
        # only edges to nodes not processed yet are cloned, so each internal
        # edge is added exactly once (when its first endpoint is processed)
        for edge in G.out_edges(node.name):
            if edge.to_node in remaining:
                sub_g.add_edge(edge.clone())
        for edge in G.in_edges(node.name):
            if edge.from_node in remaining:
                sub_g.add_edge(edge.clone())
    return sub_g
def match(self, G: GraphView, set_identity: bool = True):
    """Replace each node set from ``self.find_sets(G)`` with an expression fusion node.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one node set was fused, otherwise False.
    """
    has_modified_graph = False
    for node_set in self.find_sets(G):
        has_modified_graph = True
        # in_edges/out_edges: {(from_node, from_idx): group of edges}
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        for edge in internal_edges:
            frag.add_edge(edge)
        # expression input i feeds every consumer in the i-th incoming edge group
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        LOG.info('matched expression - creating expression %s', self._expr_num)
        expr = ExpressionFusionParameters(f"expr_{self._expr_num}",
                                          subgraph=frag,
                                          input_mapping=in_mapping,
                                          output_mapping=out_mapping,
                                          in_dims=in_dims,
                                          out_dims=out_dims,
                                          constant_inputs=constant_inputs)
        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        # set().union(...) instead of set.union(...): the latter raises
        # TypeError when there are no edge groups at all
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set().union(*in_edges.values())),
            frag_out_edges=list(set().union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping)
        self._expr_num += 1
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse connected groups of FUSE_NODES (plus single-element constants) into expressions.

    Each connected node set (as grouped by ``group_nodes``) that is not made
    only of constants becomes one ``ExpressionFusionParameters`` node.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one expression was created, otherwise False.
    """
    has_modified_graph = False
    # collect connected node sets
    node_sets = group_nodes(G, [
        node for node in G.nodes()
        if isinstance(node, FUSE_NODES) or (
            isinstance(node, ConstantInputParameters) and node.out_dims[0].size() == 1)
    ])
    # remove sets that are only ConstantInputs
    node_sets = [
        node_set for node_set in node_sets
        if not all(isinstance(node, ConstantInputParameters) for node in node_set)
    ]
    for node_set in node_sets:
        has_modified_graph = True
        # in_edges/out_edges: {(from_node, from_idx): group of edges}
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        for edge in internal_edges:
            frag.add_edge(edge)
        # in_mapping, constant_inputs and edge_in_mapping below must all follow
        # the same (dict) order of in_edges so that expression input i is fed
        # by the i-th external producer
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            isinstance(node_edge_idx[0], ConstantInputParameters)
            for node_edge_idx in in_edges
        ]
        expr = ExpressionFusionParameters("expr_%s" % self._expr_num,
                                          subgraph=frag,
                                          input_mapping=in_mapping,
                                          output_mapping=out_mapping,
                                          constant_inputs=constant_inputs)
        G.replace_fragment(
            frag,
            expr,
            # set().union(...) does not fail when there are no edge groups
            frag_in_edges=list(set().union(*in_edges.values())),
            frag_out_edges=list(set().union(*out_edges.values())),
            # not re-sorted: sorting by from_idx broke the alignment with in_mapping
            edge_in_mapping=list(in_edges.keys()),
            edge_out_mapping=[[(edge.to_node, edge.to_idx) for edge in edge_set]
                              for edge_set in out_edges.values()])
        # advance the counter so every expression gets a distinct name
        self._expr_num += 1
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def find_concats_up(G, concat, subgraph: GraphView = None):
    """Collect the tree of axis-0 concats feeding ``concat`` into one subgraph.

    Produces a subgraph of concats operating on axis 0, separated by copies
    or reshapes. The output node is the final concat. The input nodes are all
    the inputs of a single condensed concat that can replace this subgraph.

    Args:
        G: graph to search.
        concat: the final concat node.
        subgraph: graph to add to; a new ``GraphView`` is created when None.

    Returns:
        The (possibly newly created) subgraph.
    """
    result = GraphView() if subgraph is None else subgraph
    for in_edge in G.indexed_in_edges(concat.name):
        path = traverse_to_concat(G, in_edge, result)
        if not path:
            # nothing concat-like upstream: this input becomes a DummyInput
            dummy = DummyInput(f"{in_edge.from_node.name}_{in_edge.from_idx}", in_edge)
            result.add_edge(NNEdge(from_node=dummy,
                                   to_node=in_edge.to_node,
                                   to_idx=in_edge.to_idx))
            continue
        for inter_edge in path:
            result.add_edge(inter_edge)
    return result
def match_function(self, G: GraphView):
    """Match a ReLU with ``upper_bound == 6`` multiplied by a constant equal to 1/6.

    The ReLU feeds input 0 of a ``MatrixMulParameters`` node. A
    ``ConstantInputParameters`` node for which ``check_equals(G, node, 1/6)``
    holds feeds input 1.
    """
    def is_relu6(candidate):
        return (isinstance(candidate, ReluActivationParameters)
                and candidate.upper_bound == 6)

    def is_mul(candidate):
        return isinstance(candidate, MatrixMulParameters)

    def is_one_sixth(candidate):
        return (isinstance(candidate, ConstantInputParameters)
                and check_equals(G, candidate, 1.0 / 6.0))

    pattern = GraphView()
    for label, test in (('0', is_relu6), ('1', is_mul), ('2', is_one_sixth)):
        pattern.add_node(MatchNode(label, matcher=test))
    pattern.add_edge(Edge('0', '1', to_idx=0))
    pattern.add_edge(Edge('2', '1', to_idx=1))
    return G.match_fragment(pattern)
def match_function(self, G: GraphView):
    """Match a Conv2D node followed by an activation and/or pooling node.

    The configuration flags choose the shape of the pattern:

    - ``match_activation`` and ``match_pool`` both set: conv -> activation ->
      pool when ``pool_after_activation`` is set, otherwise conv -> pool ->
      activation.
    - only one of them set: conv -> that single node.
    - neither set: the conv node alone.
    """
    pattern = GraphView()
    pattern.add_node(MatchNode(
        '0',
        matcher=lambda candidate: (isinstance(candidate, Conv2DParameters)
                                   and self.valid_activation(candidate))))

    # ordered list of node builders to append after the conv
    stages = []
    if self.match_activation and self.match_pool:
        if self.pool_after_activation:
            stages = [self.add_activation, self.add_pooling]
        else:
            stages = [self.add_pooling, self.add_activation]
    elif self.match_activation:
        stages = [self.add_activation]
    elif self.match_pool:
        stages = [self.add_pooling]

    for position, add_stage in enumerate(stages, start=1):
        add_stage(str(position), pattern)
    for position in range(1, len(stages) + 1):
        pattern.add_edge(Edge(str(position - 1), str(position)))
    return G.match_fragment(pattern)
def construct_subgraph(G, nodes):
    """Build a fusion subgraph from ``nodes`` with explicit input/output nodes.

    The internal edges of ``G`` between ``nodes`` are copied into a new
    ``GraphView``. Then:

    - every distinct external output ``(from_node, from_idx)`` that feeds a
      subgraph input node gets a ``FusionInputParameters`` node;
    - every used output index of a subgraph output node gets a
      ``FusionOutputParameters`` node.

    NOTE(review): inputs are only gathered for nodes returned by
    ``subg.inputs()`` and outputs only for ``subg.outputs()``, as before.
    Confirm that callers never pass sets with mixed internal/external
    producers.

    Args:
        G: source graph; not modified.
        nodes: iterable of nodes of ``G``.

    Returns:
        ``(subg, inputs_map, outputs_map)``:

        - ``inputs_map[i]`` is the ``(from_node, from_idx)`` in ``G`` feeding
          the i-th created input node.
        - ``outputs_map`` has one list per subgraph output node. That list
          holds, for each used output index in ascending order, a list of the
          ``(to_node, to_idx)`` consumers in ``G``.
    """
    # dedupe while keeping the caller's order so the result is deterministic
    node_list = list(dict.fromkeys(nodes))
    node_set = set(node_list)
    subg = GraphView()
    for node in node_list:
        for edge in G.out_edges(node.name):
            # only add internal edges
            if edge.to_node in node_set:
                subg.add_edge(NNEdge(from_node=edge.from_node, to_node=edge.to_node,
                                     from_idx=edge.from_idx, to_idx=edge.to_idx))

    # group the external edges into the subgraph inputs by producer output;
    # dict.fromkeys de-duplicates in discovery order (a set had arbitrary order)
    external_in_edges = dict.fromkeys(
        edge for node in subg.inputs() for edge in G.in_edges(node.name))
    inputs = {}
    for edge in external_in_edges:
        inputs.setdefault((edge.from_node, edge.from_idx), []).append(
            (edge.to_node, edge.to_idx))

    inputs_map = []
    for (fnode, fidx), outs in inputs.items():
        inp = FusionInputParameters(
            f'{fnode.name}_{fidx}_in', dims=fnode.out_dims[fidx])
        inputs_map.append((fnode, fidx))
        for (tnode, tidx) in outs:
            subg.add_edge(NNEdge(from_node=inp, to_node=tnode, to_idx=tidx))

    # snapshot the outputs before FusionOutputParameters nodes are added
    output_nodes = list(subg.outputs())
    outputs_map = []
    for node in output_nodes:
        node_out_edges = G.out_edges(node.name)
        output_map = []
        outputs_map.append(output_map)
        for fidx in sorted({edge.from_idx for edge in node_out_edges}):
            # materialise the consumers now: a lazy generator would only be
            # evaluated after the caller rewires G, and only once
            output_map.append([(edge.to_node, edge.to_idx)
                               for edge in node_out_edges if edge.from_idx == fidx])
            outp = FusionOutputParameters(
                f'{node.name}_{fidx}_out', dims=node.out_dims[fidx])
            subg.add_edge(NNEdge(from_node=node, to_node=outp, from_idx=fidx))
    return (subg, inputs_map, outputs_map)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each Conv2D node with the activation/pooling nodes matched after it.

    The acceptable activations depend on ``kwargs['group_identity']``:
    ``VALID_ACTIVATIONS_POW2`` for ``'pow2_match_group'``, otherwise
    ``VALID_ACTIVATIONS_SQ8``. Chains of two or more nodes returned by
    ``get_node_list`` are wrapped in a ``ConvFusionParameters`` node, which
    replaces them in ``G``.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.
        **kwargs: only ``group_identity`` is read.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    group_identity = kwargs.get('group_identity')
    if group_identity == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8
    for conv_node in [
            params for params in G.nodes() if isinstance(params, Conv2DParameters)
    ]:
        node_list = self.get_node_list(G, conv_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue
        if node_list.fusion_type == 'conv_active_pool':
            if node_list.pool.pool_type == "average":
                # keep only conv + activation; the average pool stays in G
                # as a separate node
                node_list.order = node_list.order[:2:]
                node_list.pool = None
        elif node_list.fusion_type == 'conv_pool_active':
            # NOTE: This is only for old POW2 kernels - SQ8 can handle this
            # NOTE(review): this skip is applied whatever group_identity is -
            # confirm SQ8 graphs are meant to be excluded too.
            if node_list.pool.pool_type == "average" and node_list.active.activation != "relu":
                continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        # chain the matched nodes one after another inside the fusion subgraph
        subgraph = GraphView()
        last_node = None
        for node in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
            last_node = node
        # three fusion inputs all map onto the conv
        # (presumably input, weights, biases - TODO confirm)
        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(
            node_list.conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=node_list.conv.in_dims_hint,
            out_dims_hint=node_list.conv.out_dims_hint,
            in_dims=deepcopy(node_list.conv.in_dims),
            out_dims=deepcopy(node_list.order[-1].out_dims),
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # TODO - stats
                # deep copies so the fusion record does not share qtype
                # objects with the contained records
                prec = QRec.copy_ktype(qrecs[0],
                                       in_qs=deepcopy(qrecs[0].in_qs),
                                       out_qs=deepcopy(qrecs[-1].out_qs))
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture boundary edges before the matched nodes are removed
        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace each node set from ``self.find_sets(G)`` with an expression node.

    If ``G`` is quantized, each expression gets a scaled ``QRec``. Its min/max
    stats come from the input and output qtypes of the replaced nodes. Quantize
    nodes directly after the expression are removed, and the graph is then
    re-quantized starting from the new expressions.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.
        **kwargs: not read by this method.

    Returns:
        True if at least one node set was fused, otherwise False.
    """
    has_modified_graph = False
    to_quantize = []
    node_sets = self.find_sets(G)
    for node_set in node_sets:
        # fresh stats collector for the symbols of this expression
        Symbol.set_default_control(SymbolStats())
        has_modified_graph = True
        # in_edges/out_edges: {(from_node, from_idx): group of edges}
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        for node in node_set:
            frag.add_node(node)
        for edge in internal_edges:
            frag.add_edge(edge)
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        LOG.debug(
            "inputs coming from: %s",
            ",".join(f"{from_node.__repr__()}:{from_idx}"
                     for from_node, from_idx in in_edges))
        LOG.info("fusing nodes: %s into expr_%s",
                 ",".join(node.__repr__() for node in node_set), self._expr_num)
        expr = ExpressionFusionParameters(
            G.unique_name(f"expr_{self._expr_num}"),
            subgraph=frag,
            qrecs=G.quantization,
            input_mapping=in_mapping,
            output_mapping=out_mapping,
            in_dims=in_dims,
            out_dims=out_dims,
            constant_inputs=constant_inputs)
        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        # set().union(...) instead of set.union(...): the latter raises
        # TypeError when there are no edge groups at all
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set().union(*in_edges.values())),
            frag_out_edges=list(set().union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping,
            edge_class=NNEdge)
        if G.quantization:
            qrecs = G.quantization
            # input qtype i: the first consumer of expression input i
            in_qs = [
                qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                for in_map in in_mapping
            ]
            out_qs = [
                qrecs[NodeId(node)].out_qs[idx]
                for node, idx in out_mapping
            ]
            stats = Symbol.CURRENT_CONTROL.stats
            func_col = expr.func_col
            for idx, qtype in enumerate(in_qs):
                symbol = func_col.variables[func_col.input_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            for idx, qtype in enumerate(out_qs):
                symbol = func_col.variables[func_col.output_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                out_qs=out_qs,
                                                expression=stats,
                                                ktype='scaled')
            # delete any quantize parameters on outputs to allow the quantizer
            # to fuse them into the expression
            out_edges = G.out_edges(expr.name)
            for edge in out_edges:
                if isinstance(edge.to_node, QuantizeParameters):
                    G.remove_and_reconnect(edge.to_node)
                    if NodeId(edge.to_node) in G.quantization:
                        del G.quantization[NodeId(edge.to_node)]
            to_quantize.append(expr)
        self._expr_num += 1
    if to_quantize:
        quantizer = UnifiedQuantizer.from_quantized_graph(G)
        G.quantization = quantizer.quantize(G, start_nodes=to_quantize)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse a Pad feeding one add input (plus an optional activation) into one node.

    The fusion keeps the add's input order: fusion input ``k`` corresponds to
    add input ``k``. The padded operand enters through the pad node's single
    input (index 0). The other operand goes straight to the add.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.
        **kwargs: not read by this method.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    for pad_node in [
            params for params in G.nodes() if isinstance(params, PadParameters)
    ]:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # which add input the padded tensor arrives on
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad,
                   to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # the pad has a single input (index 0) whichever add input it feeds;
        # the previous (pad, 1) mapping referenced a non-existent pad input
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 0)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            # if there are quantization stats then clear them. They need to be created again
            G.quantization.stats = None
            if qrecs:
                # NOTE(review): in_qs comes from the first contained record
                # only - confirm it covers both fusion inputs
                prec = QRec.copy_ktype(qrecs[0],
                                       in_qs=qrecs[0].in_qs,
                                       out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # Pair each external input edge with the fusion input it must feed.
        # Edges into the pad arrive on pad input 0, so their to_idx cannot be
        # reused on the fusion: they go to fusion input padded_input_idx.
        # Reusing it made them collide with add input 0 when the pad fed input 1.
        if padded_input_idx == 0:
            in_edges = (
                [(edge, padded_input_idx) for edge in G.in_edges(node_list.pad.name)] +
                [(edge, edge.to_idx) for edge in G.indexed_in_edges(node_list.add.name)[1::]])
        else:
            in_edges = (
                [(edge, edge.to_idx) for edge in G.indexed_in_edges(node_list.add.name)[0:1:]] +
                [(edge, padded_input_idx) for edge in G.in_edges(node_list.pad.name)])
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge, fusion_idx in in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       pnode,
                       from_idx=edge.from_idx,
                       to_idx=fusion_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode,
                       edge.to_node,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse a Pad feeding one add input (plus an optional activation) into one node.

    The fusion keeps the add's input order: fusion input ``k`` corresponds to
    add input ``k``. The padded operand enters through the pad node's single
    input (index 0). The other operand goes straight to the add.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    for pad_node in [
            params for params in G.nodes() if isinstance(params, PadParameters)
    ]:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # which add input the padded tensor arrives on
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad,
                   to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # the pad has a single input (index 0) whichever add input it feeds;
        # the previous (pad, 1) mapping referenced a non-existent pad input
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 0)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # NOTE(review): if qrecs[0] is none of the record types below,
                # prec stays None and None is stored as the fusion's record.
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # Pair each external input edge with the fusion input it must feed.
        # Edges into the pad arrive on pad input 0, so their to_idx cannot be
        # reused on the fusion: they go to fusion input padded_input_idx.
        if padded_input_idx == 0:
            in_edges = (
                [(edge, padded_input_idx) for edge in G.in_edges(node_list.pad.name)] +
                [(edge, edge.to_idx) for edge in G.indexed_in_edges(node_list.add.name)[1::]])
        else:
            in_edges = (
                [(edge, edge.to_idx) for edge in G.indexed_in_edges(node_list.add.name)[0:1:]] +
                [(edge, padded_input_idx) for edge in G.in_edges(node_list.pad.name)])
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge, fusion_idx in in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       pnode,
                       from_idx=edge.from_idx,
                       to_idx=fusion_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode,
                       edge.to_node,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse a MatMul with a following bias add and/or activation.

    The add node (if matched) is absorbed. Its constant input is connected
    to fusion input 2, which maps to matmul input 2. The activation (if
    matched) is placed after the matmul inside the fusion subgraph.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    for matmul_node in [
            params for params in G.nodes() if isinstance(params, MatMulOpParameters)
    ]:
        node_list = self.get_node_list(G, matmul_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        if node_list.active is not None:
            subgraph.add_edge(
                NNEdge(from_node=node_list.matmul, to_node=node_list.active))
        else:
            # matmul + add only: without this the subgraph would be empty and
            # the fusion would contain no nodes (nor their quantization records)
            subgraph.add_node(node_list.matmul)
        input_mapping = [[(node_list.matmul, idx)] for idx in range(2)]
        if node_list.add:
            # the bias taken from the add becomes matmul input 2
            input_mapping += [[(node_list.matmul, 2)]]
        output_mapping = [(node_list.active, 0)] if node_list.active else [(node_list.matmul, 0)]
        pnode = MatMulOpFusionParameters(node_list.matmul.name + '_fusion',
                                         fusion_type=node_list.fusion_type,
                                         subgraph=subgraph,
                                         input_mapping=input_mapping,
                                         output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # NOTE(review): if qrecs[0] is none of the record types below,
                # prec stays None and None is stored as the fusion's record.
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture boundary edges before the matched nodes are removed
        in_edges = G.in_edges(node_list.matmul.name)
        if node_list.add:
            bias_edge = [
                add_edge for add_edge in G.in_edges(node_list.add.name)
                if isinstance(add_edge.from_node, ConstantInputParameters)
            ][0]
        out_edges = G.out_edges(node_list.order[-1].name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        if node_list.add:
            G.add_edge(
                NNEdge(bias_edge.from_node, pnode, from_idx=bias_edge.from_idx, to_idx=2))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse pooling / global pooling nodes with the activation nodes matched around them.

    The activation sets passed to ``get_node_list`` depend on
    ``kwargs['group_identity']``: the POW2 variants for ``'pow2_match_group'``,
    otherwise the SQ8 variants. Chains of two or more nodes are wrapped in an
    ``ActivationFusion`` node, which replaces them in ``G``.

    Args:
        G: graph to rewrite in place.
        set_identity: when True, ``self.set_identity(G)`` is called at the end.
        **kwargs: only ``group_identity`` is read.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    group_identity = kwargs.get('group_identity')
    if group_identity == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
        valid_activations_wo_pool = VALID_ACTIVATIONS_POW2_WO_POOL
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8
        valid_activations_wo_pool = VALID_ACTIVATIONS_SQ8_WO_POOL
    for pool_node in G.nodes(node_classes=(PoolingParameters, GlobalPoolingParameters)):
        node_list = self.get_node_list(G, pool_node, valid_activations, valid_activations_wo_pool)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        # chain the matched nodes one after another inside the fusion subgraph
        subgraph = GraphView()
        last_node = None
        for node in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
            last_node = node
        # the single fusion input feeds the pool node
        input_mapping = [[(node_list.pool, 0)]]
        output_mapping = [(last_node, 0)]
        pnode = ActivationFusion(node_list.pool.name + '_fusion',
                                 fusion_type=node_list.fusion_type,
                                 subgraph=subgraph,
                                 input_mapping=input_mapping,
                                 output_mapping=output_mapping)
        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = QRec.copy_ktype(qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                    if isinstance(node, GlobalPoolingParameters):
                        # Global pooling fused with activations need to have only the activation scale
                        # (its output qtype becomes a copy of its input qtype,
                        # retyped to int32)
                        G.quantization[NodeId(pnode, node)].out_qs[0] = deepcopy(
                            G.quantization[NodeId(
                                pnode, node)].in_qs[0])
                        G.quantization[NodeId(
                            pnode, node)].out_qs[0].dtype = np.int32
                G.quantization[NodeId(pnode)] = prec
        # capture boundary edges before the matched nodes are removed
        in_edges = G.in_edges(node_list.pool.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match_function(self, G: GraphView):
    """Match the one-node pattern made of ``NoOPMatcher('0')`` against ``G``."""
    single_node_pattern = GraphView()
    single_node_pattern.add_node(NoOPMatcher('0'))
    matches = G.match_fragment(single_node_pattern)
    return matches