def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Absorb a transpose on input 1 of a matmul by flipping the matmul variant.

    If the second input of a MatMulOpParameters node is produced by a
    TransposeParameters node, the explicit transpose can be made implicit by
    switching MatMulOpParameters <-> MatMulTransposedParameters, removing the
    transpose node and wiring its producer directly to input 1.

    Returns True if the graph was modified.

    Fix: `indexed_in_edges` already returns a list; the original wrapped it
    (and the transpose's in-edge lookup) in pointless list comprehensions.
    """
    has_modified_graph = False
    for node in G.nodes(node_classes=MatMulOpParameters):
        in_edges = G.indexed_in_edges(node.name)
        trans_node = in_edges[1].from_node
        if not isinstance(trans_node, TransposeParameters):
            continue
        # flip to the opposite variant so the transpose becomes implicit
        if isinstance(node, MatMulTransposedParameters):
            new_node = MatMulOpParameters(node.name)
        else:
            new_node = MatMulTransposedParameters(node.name)
        in_trans_edge = G.indexed_in_edges(trans_node.name)[0]
        G.replace_node(node.name, new_node)
        G.remove(trans_node)
        # bypass the removed transpose: its producer now feeds input 1 directly
        G.add_edge(NNEdge(in_trans_edge.from_node, new_node,
                          from_idx=in_trans_edge.from_idx, to_idx=1))
        has_modified_graph = True
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse single matched mat-scale nodes into MatScaleFusionParameters.

    A one-node GraphMatcher fragment (MatScaleNodeMatch) locates candidates;
    each matched node is replaced by a fusion node that holds the fragment as
    its subgraph, with the original in/out edges redirected onto it.

    Returns True if the graph was modified.
    """
    # match function hands back (fragment, match-state) for every hit
    fragment = GraphMatcher(
        match_function=lambda state, frag: (frag, state['match']))
    fragment.add_node(MatScaleNodeMatch())
    has_modified_graph = False
    for frag, match in fragment.match_graph(G):
        # in-edges feeding the matched node, in the order recorded by the match
        match_edges = [
            G.indexed_in_edges(node.name)[idx]
            for node, idx in match['inputs']
        ]
        matched_node = list(frag.nodes())[0]
        out_edges = G.out_edges(matched_node.name)
        has_modified_graph = True
        G.remove(matched_node)
        fnode = MatScaleFusionParameters(
            "{}_fusion".format(matched_node.name),
            fusion_type=match['type'],
            subgraph=frag,
            input_mapping=[[(matched_node, 0)], [(matched_node, 1)]])
        G.add_node(fnode)
        # NOTE: the original edge objects are mutated in place and re-added,
        # preserving their from_node/from_idx while retargeting the fusion node
        for idx, edge in enumerate(match_edges):
            edge.to_node = fnode
            edge.to_idx = idx
            G.add_edge(edge)
        for edge in out_edges:
            edge.from_node = fnode
            G.add_edge(edge)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def reverse_matmul(G: GraphView, params):
    """Swap the two matmul operands and bracket the op with transposes.

    Implements the identity A.B = (B^T . A^T)^T: the first two input edges
    are crossed over, the per-input quantization records (if any) are
    exchanged to stay aligned with the swapped operands, and (1, 0)
    transposes are inserted on both inputs and on the output.

    Returns (list of the two input transpose nodes, the output transpose node).
    """
    first_two = G.indexed_in_edges(params.name)[:2]
    for edge in first_two:
        G.remove_edge(edge)
    # cross the operands: old input 0 becomes input 1 and vice versa
    for dest_idx, edge in zip((1, 0), first_two):
        G.add_edge(
            NNEdge(from_node=edge.from_node,
                   to_node=params,
                   from_idx=edge.from_idx,
                   to_idx=dest_idx))
    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        # keep input qrecs aligned with the swapped operands
        qrec = G.quantization[nid]
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]
    # transpose each input on its way in
    in_nodes = [
        TransposeParameters(G.unique_name(f"{params.name}_tin{idx}"),
                            transpose=(1, 0))
        for idx in range(2)
    ]
    for idx, tin_params in enumerate(in_nodes):
        G.insert_node_before(tin_params, params, to_idx=idx, edge_class=NNEdge)
    # and undo the transposition on the way out
    tout_params = TransposeParameters(G.unique_name(f"{params.name}_tout"),
                                      transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Shrink conv filters that are larger than their input along H and/or W.

    Only handles the case where the smaller of (filter, input) extent is 1 in
    a dimension; other oversize cases are left alone with a warning (they are
    said to still work on the AT backend). The filter dims and the weight
    tensor are clipped to match.

    Returns True if any filter was changed.
    """
    something_changed = False
    filt_nodes = [node for node in G.nodes()
                  if isinstance(node, (Conv2DParameters, ConvFusionParameters))]
    for filt_node in filt_nodes:
        # pnode is the graph-level node (the fusion if present); filt_node is
        # the actual convolution whose filter we inspect
        pnode = filt_node
        if isinstance(filt_node, ConvFusionParameters):
            cnodes = filt_node.contained_nodes()
            filt_node = cnodes[0]
            if not isinstance(filt_node, Conv2DParameters):
                continue
        in_dim = filt_node.in_dims
        filt_dim = filt_node.filter
        # filter fits inside the input - nothing to do
        if filt_dim.h <= in_dim[0].h and filt_dim.w <= in_dim[0].w:
            continue
        min_h = min(filt_dim.h, in_dim[0].h)
        min_w = min(filt_dim.w, in_dim[0].w)
        # only the degenerate (extent 1) case is rewritten
        if min_h > 1 and min_w > 1:
            LOG.warning("Filter of %s [%dx%d] bigger than input [%dx%d] not optimal but will work on AT",
                        filt_node.name, filt_dim.h, filt_dim.w, in_dim[0].h, in_dim[0].w)
            continue
        ker_h = 1 if min_h == 1 else filt_dim.h
        ker_w = 1 if min_w == 1 else filt_dim.w
        if ker_h == filt_dim.h and ker_w == filt_dim.w:
            continue
        new_filt_dim = Conv2DFilterDim(
            ker_h, ker_w, filt_dim.out_c, in_c=filt_dim.in_c)
        LOG.warning("Converting filter of %s from [%dx%d] -> [%dx%d]",
                    filt_node.name, filt_dim.h, filt_dim.w, new_filt_dim.h, new_filt_dim.w)
        filt_node.filter = new_filt_dim
        # build an index per weight axis: full slice on channel axes, clipped
        # slice on h/w. When an axis collapses to 1 the kept row/column is the
        # one aligned with the top/left padding (the tap that actually
        # overlaps the input).
        new_w_idxs = []
        for dim in filt_dim.order:
            if dim in ('out_c', 'in_c'):
                new_w_idxs.append(slice(None))
            elif dim == 'h':
                if new_filt_dim.h == 1:
                    new_w_idxs.append(
                        slice(filt_node.padding.t, filt_node.padding.t + 1))
                else:
                    new_w_idxs.append(slice(0, new_filt_dim.h))
            elif dim == 'w':
                if new_filt_dim.w == 1:
                    new_w_idxs.append(
                        slice(filt_node.padding.l, filt_node.padding.l + 1))
                else:
                    new_w_idxs.append(slice(0, new_filt_dim.w))
        # weights are input 1 of the graph-level node (fusion or plain conv)
        weights_node = G.indexed_in_edges(pnode.name)[1].from_node
        weights_node.value = weights_node.value[tuple(new_w_idxs)]
        weights_node.dims = Dim.unnamed(weights_node.value.shape)
        something_changed = True
    if set_identity:
        self.set_identity(G)
    return something_changed
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Delete concat -> split pairs that exactly undo each other.

    A split whose single input is a concat (with no other consumers, no
    transposes, same axis, and piece sizes matching the concat inputs
    one-for-one) is a no-op; both nodes are removed and the concat's
    producers are wired straight through to the split's consumers.

    Returns True if the graph was modified.
    """
    graph_changed = False
    for split_node in {n for n in G.nodes() if isinstance(n, SplitParameters)}:
        in_edges = G.in_edges(split_node.name)
        if len(in_edges) > 1:
            continue
        feeder = in_edges[0].from_node
        if not isinstance(feeder, ConcatParameters):
            continue
        concat_node = feeder
        # concat must feed only this split
        if len(G.out_edges(concat_node.name)) > 1:
            continue
        # any pending transpose makes the pair non-trivial
        if concat_node.transpose_out or split_node.transpose_in:
            continue
        if concat_node.axis != split_node.axis:
            continue
        axis = concat_node.axis
        piece_sizes = [shape[axis] for shape in split_node.out_shapes]
        if len(piece_sizes) != len(concat_node.in_dims):
            continue
        # each split piece must be exactly one concat input
        if any(size != in_dim.shape[axis]
               for size, in_dim in zip(piece_sizes, concat_node.in_dims)):
            continue
        graph_changed = True
        LOG.info("removing unnecessary concat/split pair %s/%s",
                 concat_node.name, split_node.name)
        concat_in_edges = G.indexed_in_edges(concat_node.name)
        split_out_edges = G.indexed_out_edges(split_node.name)
        G.remove(split_node)
        G.remove(concat_node)
        # wire concat producer i straight to every consumer of split output i
        for in_edge, out_bundle in zip(concat_in_edges, split_out_edges):
            for out_edge in out_bundle:
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           from_idx=in_edge.from_idx,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return graph_changed
def move_constant(cls, G: GraphView, params, in_qs):
    """Ensure the constant operand of a matmul sits on input 2.

    If input 2 is already constant, nothing changes. If input 1 is the
    constant, the first two edges are swapped and (1, 0) transposes are
    inserted on both inputs and the output (A.B = (B^T.A^T)^T), so the
    constant ends up on input 2 with equivalent results.

    Returns (constant node, possibly reordered in_qs); (None, in_qs) when
    no rewrite is possible.
    """
    # looks for a constant on one of the inputs
    # if there is one we can scale by the second dimension of the second
    # tensor. If the constant is on the first tensor then move to the second
    # and transpose the operation
    in_edges = G.indexed_in_edges(params.name)
    in1_node = in_edges[0].from_node
    in2_node = in_edges[1].from_node
    if isinstance(in2_node, ConstantInputParameters):
        # constant already on input 2 - nothing to do
        return in2_node, in_qs
    elif isinstance(in1_node, ConstantInputParameters):
        if len(params.in_dims) > 2:
            # check if the bias has the correct length to move constant
            # it must have a length equal to the second tensors second dimension after transpose
            bias_size = params.in_dims[2].size()
            in1_shape = params.in_dims[0].shape
            if in1_shape[1] != bias_size:
                return None, in_qs
        for edge in in_edges[:2:]:
            G.remove_edge(edge)
        to_idx = 1
        # swap edges to move constant onto input 2
        for edge in in_edges[:2:]:
            new_edge = NNEdge(from_node=edge.from_node,
                              to_node=edge.to_node,
                              from_idx=edge.from_idx,
                              to_idx=to_idx)
            G.add_edge(new_edge)
            to_idx = 1 - to_idx
        # use A.B = (BT.AT)T identity
        tin1 = TransposeParameters(G.unique_name(f'{params.name}_tin1'),
                                   transpose=(1, 0))
        tin2 = TransposeParameters(G.unique_name(f'{params.name}_tin2'),
                                   transpose=(1, 0))
        tout = TransposeParameters(G.unique_name(f'{params.name}_tout'),
                                   transpose=(1, 0))
        # NOTE(review): unlike reverse_matmul, edge_class=NNEdge is not passed
        # to insert_node_before here - confirm the default edge class matches
        G.insert_node_before(tin1, params)
        G.insert_node_before(tin2, params, to_idx=1)
        G.insert_node_after(params, tout)
        LOG.warning('transposes inserted on %s - rerun adjust', params.name)
        # assumes in_qs is an indexable list with at least 2 entries - the
        # first two qrecs follow the swapped operands
        return in1_node, [in_qs[1], in_qs[0]] + in_qs[2::]
    else:
        return None, in_qs
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse matched mat-scale node pairs into one MatScaleFusionParameters.

    Fragments found by MatScalePairMatchFactory's matcher are replaced by a
    "vec_scalar" fusion node holding the fragment as a subgraph; the
    fragment's external in/out edges are cloned onto the fusion node.

    Returns True if the graph was modified.
    """
    fac = MatScalePairMatchFactory()
    has_modified_graph = False
    for frag, match in fac.get_matcher().match_graph(G):
        # in-edges feeding the fragment, as recorded (node, input idx) pairs
        match_edges = [
            G.indexed_in_edges(node.name)[idx] for node, idx in match
        ]
        first_node = frag.inputs()[0]
        last_node = frag.outputs()[0]
        out_edges = G.out_edges(last_node.name)
        for node in frag.nodes():
            G.remove(node)
        input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
            match_edges)
        fnode = MatScaleFusionParameters(
            "{}_{}_fusion".format(first_node.name, last_node.name),
            fusion_type="vec_scalar",
            subgraph=frag,
            input_mapping=MatScaleFusionParameters.convert_input_mapping(
                input_mapping))
        has_modified_graph = True
        G.add_node(fnode)
        # three inputs on the fusion; hints filled in from the producers below
        fnode.in_dims_hint = [None] * 3
        for idx, edge in enumerate(match_edges):
            # retarget each external in-edge onto the fusion input chosen by
            # the input mapping for its original destination node
            new_edge = edge.clone(
                to_node=fnode,
                to_idx=list(input_mapping[edge.to_node].keys())[0])
            if new_edge.from_node.out_dims_hint:
                fnode.in_dims_hint[idx] = new_edge.from_node.out_dims_hint[
                    edge.from_idx]
            G.add_edge(new_edge)
        for edge in out_edges:
            new_edge = edge.clone(from_node=fnode)
            G.add_edge(new_edge)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Insert a copy node on split/concat inputs fed (directly or through
    pass-through nodes) by a graph input or constant.

    This prevents the backend from aliasing an input/constant buffer into a
    split or concat region.

    Returns True if the graph was modified.

    Fix: the quantization record copied for the new copy node now uses the
    index of the input actually being copied. The original always copied the
    qrec of input 0, which is wrong for a concat input at idx > 0 whenever
    the per-input qrecs differ.
    """
    candidates = [node for node in
                  G.nodes(node_classes=(SplitParameters, ConcatParameters))]
    need_a_copy_edges = []
    for node in candidates:
        for idx, edge in enumerate(G.indexed_in_edges(node.name)):
            # trace through pass-through nodes to the real producer
            real_from_node, _ = find_real_in_edge(G, edge)
            if isinstance(real_from_node, (InputParameters, ConstantInputParameters)):
                need_a_copy_edges.append((edge, idx))
    has_modified_graph = False
    for edge, idx in need_a_copy_edges:
        LOG.info("Insert copy on split input %s", edge.to_node.name)
        has_modified_graph = True
        cnode = CopyParameters(G.unique_name(f'{edge.to_node.name}_copy'))
        G.insert_node_at_edge(cnode, edge)
        if G.quantization:
            # copy the qrec of the input we are splicing into, not input 0
            G.quantization.copy_qrec(edge.to_node, 'in', idx, cnode)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse Pad -> Add (-> activation) chains into PaddedAddFusionParameters.

    For each PadParameters node, get_node_list collects the fusable chain.
    The chain is rebuilt as a subgraph inside a fusion node; quantization
    records (legacy record classes) are merged into one record for the
    fusion, and the external edges are rewired onto it.

    Returns True if the graph was modified.
    """
    has_modified_graph = False
    for pad_node in [
            params for params in G.nodes()
            if isinstance(params, PadParameters)
    ]:
        node_list = self.get_node_list(G, pad_node)
        # need at least pad + add to fuse
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # which add input the pad feeds (0 or 1)
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad, to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # fusion input order mirrors which side of the add is padded
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            # build a fusion-level qrec of the same kind as the first
            # contained record: inputs from the first, outputs from the last
            prec = None
            if isinstance(qrecs[0], (SymmetricQuantizationRecord,
                                     SymmetricScalableFilterQuantizationRecord)):
                prec = SymmetricQuantizationRecord(
                    in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (MultQuantizationRecord,
                                       MultScalableFilterQuantizationRecord)):
                prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                              out_qs=qrecs[-1].out_qs)
            elif isinstance(qrecs[0], (Float32QuantizationRecord,
                                       Float32ScalableFilterQuantizationRecord)):
                prec = Float32QuantizationRecord(
                    in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
        # gather the external in-edges in fusion input order
        if padded_input_idx == 0:
            in_edges = G.in_edges(node_list.pad.name) + G.indexed_in_edges(
                node_list.add.name)[1::]
        else:
            in_edges = G.indexed_in_edges(
                node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        # NOTE(review): to_idx is reused from the original destination edge -
        # presumably this matches the fusion's input numbering; confirm for
        # the padded_input_idx == 1 case
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Absorb a constant add/mul that follows a filter into the filter itself.

    Repeatedly (while the pattern persists) looks at the single consumer of
    each filter node; if it is one of the ops in OPS with a constant on its
    other input, the constant is folded into the filter's biases (and, for
    ops that allow it, weights), the op node is removed and the filter is
    wired to the op's consumers.

    Returns True if the graph was modified.

    Fix: `remove_constant` was `len(G.out_edges(const_node.name))`, which is
    always >= 1 (the edge to op_node itself counts), so a constant shared
    with other consumers was removed, breaking them. It must be `== 1`.
    """
    has_modified_graph = False
    filter_nodes = [
        node for node in G.nodes() if isinstance(node, FilterParameters)
    ]
    for filter_node in filter_nodes:
        while True:
            out_edges = G.out_edges(filter_node.name)
            # can't fuse if there is a branch
            if len(out_edges) > 1:
                break
            out_edge = out_edges[0]
            op_node = out_edge.to_node
            # must be a valid matrix op
            if not isinstance(op_node, tuple(OPS.keys())):
                break
            # other edge to the op must be a constant
            other_idx = 1 if out_edge.to_idx == 0 else 0
            other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
            if not isinstance(other_in_edge.from_node, ConstantInputParameters):
                break
            const_node = other_in_edge.from_node
            # only delete the constant if this op is its sole consumer
            remove_constant = len(G.out_edges(const_node.name)) == 1
            flat_value = const_node.dqvalue.flatten()
            out_c = filter_node.filter.out_c
            op, weights_and_biases = OPS[op_node.__class__]
            # it would be possible to support mult bias addition by out channel
            # but only supporting a scalar at present
            if len(flat_value) != 1 and (weights_and_biases
                                         or len(flat_value) != out_c):
                LOG.warning('could not absorb %s into %s', const_node.name,
                            filter_node.name)
                break
            # If there is quantization then essentially the output of the filter
            # takes the quantization of the output of the operation.
            # The biases will not change since their quantization depends on the
            # weights and input
            fnid = NodeId(filter_node)
            opnid = NodeId(op_node)
            if G.quantization and (fnid in G.quantization
                                   or opnid in G.quantization):
                if not (fnid in G.quantization and opnid in G.quantization):
                    LOG.warning(
                        'could not absorb %s into %s - graph is partially quantized',
                        const_node.name, filter_node.name)
                    break
                fqrec = G.quantization[fnid]
                opqrec = G.quantization[opnid]
                fqrec.out_qs[0] = opqrec.out_qs[0]
            has_modified_graph = True
            LOG.info("fusing bias in %s into %s", const_node.name,
                     filter_node.name)
            self.fuse_bias(G, filter_node, other_idx, op, flat_value, 2)
            if weights_and_biases:
                # TODO - need to adjust weights quantization here
                LOG.info("fusing multiplicative bias in %s into %s",
                         const_node.name, filter_node.name)
                self.fuse_bias(G, filter_node, other_idx, op, flat_value, 1)
            # bypass the op node; filter now feeds its consumers directly
            out_edges = G.out_edges(op_node.name)
            G.remove(op_node)
            if remove_constant:
                G.remove(const_node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=filter_node, to_node=edge.to_node,
                           to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold a constant add/mul following a matmul into the matmul's operands.

    An additive constant becomes (or merges into) the matmul's bias input;
    a multiplicative constant is folded into a constant operand (and bias if
    present). The op node is removed and the matmul wired to its consumers.

    Returns True if the graph was modified.

    Fixes:
    - `has_modified_graph` was initialised and returned but never set, so the
      matcher always reported no change even after rewriting the graph.
    - `bias_node.dq_value` raised AttributeError; the property used everywhere
      else in this file is `dqvalue`.
    - `remove_constant` used a bare edge count (always >= 1, the edge to
      op_node itself), deleting constants shared with other consumers; `== 1`.
    - log message typo "multaplicative" -> "multiplicative".
    """
    has_modified_graph = False
    has_transposed = False
    for params in G.nodes(node_classes=MatMulOpParameters):
        while True:
            out_edges = G.out_edges(params.name)
            # can't fuse if there is a branch
            if len(out_edges) > 1:
                break
            out_edge = out_edges[0]
            op_node = out_edge.to_node
            # must be a valid matrix op
            if not isinstance(op_node,
                              (MatrixAddParameters, MatrixMulParameters)):
                break
            # other edge to the op must be a constant
            other_idx = 1 if out_edge.to_idx == 0 else 0
            other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
            if not isinstance(other_in_edge.from_node, ConstantInputParameters):
                break
            const_node = other_in_edge.from_node
            # only delete the constant if this op is its sole consumer
            remove_constant = len(G.out_edges(const_node.name)) == 1
            flat_value = const_node.dqvalue.flatten()
            out_shape = params.out_dims[0].shape
            if len(out_shape) != 2:
                raise ValueError(
                    f'strange outputs shape of {out_shape} for matmul {params.name}'
                )
            if len(flat_value) != out_shape[0] and len(
                    flat_value) != out_shape[1]:
                LOG.info(
                    "can't fuse %s into %s - value shape is not correct for bias",
                    const_node.name, params.name)
                break
            has_bias = len(params.in_dims) == 3
            if isinstance(op_node, MatrixAddParameters):
                if has_bias:
                    if len(flat_value.shape) != len(params.in_dims[2]):
                        LOG.info(
                            "can't fuse %s into %s - bias shape is not the same",
                            const_node.name, params.name)
                        break
                    bias_node = G.indexed_in_edges(params.name)[2].from_node
                    LOG.info(
                        "folding additive bias from %s into existing bias on %s",
                        op_node.name, params.name)
                    bias_node.value = bias_node.dqvalue + flat_value
                else:
                    if len(flat_value) == out_shape[1]:
                        # matmul needs to be transposed to fuse this
                        reverse_matmul(G, params)
                        has_transposed = True
                    bias_node = ConstantInputParameters(
                        G.unique_name(f'{params.name}_bias'),
                        value=flat_value,
                        dims=Dim.unnamed(flat_value.shape))
                    G.add_edge(
                        NNEdge(from_node=bias_node, to_node=params, to_idx=2))
                    # extend the inward transpose to cover the new input
                    if params.transpose_in:
                        params.transpose_in = params.transpose_in + [None]
                    LOG.info(
                        "folding additive bias from %s into new bias on %s",
                        op_node.name, params.name)
            else:
                params_in = G.indexed_in_edges(params.name)
                consts = [
                    isinstance(edge.from_node, ConstantInputParameters)
                    for edge in params_in
                ]
                # need a constant operand on the matmul to fold the scale into
                if not any(consts):
                    break
                mult_const_node = params_in[1].from_node if consts[
                    1] else params_in[0].from_node
                mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                if has_bias:
                    bias_node = params_in[2].from_node
                    bias_node.value = bias_node.dqvalue * const_node.dqvalue
                LOG.info(
                    "folding multiplicative bias from %s into new bias on %s",
                    op_node.name, params.name)
            has_modified_graph = True
            # bypass the op node; the matmul now feeds its consumers directly
            out_edges = G.out_edges(op_node.name)
            G.remove(op_node)
            if remove_constant:
                G.remove(const_node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=params, to_node=edge.to_node,
                           to_idx=edge.to_idx))
            G.add_dimensions()
            if G.quantization:
                # requantize downstream of the modified matmul
                quantizer = UnifiedQuantizer.from_quantized_graph(G)
                quantizer.quantize(G, start_nodes=[params])
                RemoveUnnecessaryQuantizeOperators().match(G)
    if has_transposed:
        G.adjust_order()
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Deduplicate identical operations and merge groupable convolutions.

    Two passes repeated until stable:
    1. nodes fed by the same (producer, output) that are the same operation
       (ComparableParameters.is_same_operation_as) are collapsed onto one;
    2. convolutions that can_be_grouped_with each other are merged into one
       wider convolution followed by a channel split.

    Quantized graphs are not handled. Returns True if the graph was modified.
    """
    if G.quantization:
        LOG.warning(
            'match_duplicate_operations does not handle quantized graphs')
        return False

    # group key: identity of the producing node plus its output index
    def same_source_edge_fn(x):
        return f"{x.from_node.__hash__()}##{x.from_idx}"

    # NOTE(review): same_dest_edge is never used - dead code?
    def same_dest_edge(x):
        return f"{x.to_node.__hash__()}##{x.to_idx}"

    modified_graph = False
    while True:
        found_more = False
        # groupby needs its input sorted by the same key
        same_source_edges = [
            list(edge_list) for _, edge_list in groupby(
                sorted(G.edges(), key=same_source_edge_fn),
                same_source_edge_fn)
        ]
        # all have the same origin
        same_source_edges = [
            elem for elem in same_source_edges if len(elem) > 1
        ]
        same_dest_edges = []
        same_dest_group_edges = []
        for same_source_edge in same_source_edges:
            same_source_edge = [
                edge for edge in same_source_edge
                if isinstance(edge.to_node, ComparableParameters)
            ]
            # partition the fan-out into sets of identical ops and sets of
            # groupable ops, consuming edges as they are claimed
            while same_source_edge:
                first = same_source_edge.pop(0)
                others = list(
                    filter(
                        partial(
                            lambda x, y: x.to_node != y.to_node and y.
                            to_node.is_same_operation_as(G, x.to_node),
                            first), same_source_edge))
                if others:
                    same_dest_edges.append(tuple([first] + others))
                    for other in others:
                        same_source_edge.remove(other)
                    continue
                other_groups = list(
                    filter(
                        partial(
                            lambda x, y: x.to_node != y.to_node and y.
                            to_node.can_be_grouped_with(x.to_node),
                            first), same_source_edge))
                if other_groups:
                    same_dest_group_edges.append(
                        tuple([first] + other_groups))
                    for other in other_groups:
                        same_source_edge.remove(other)
        # all are multiple edges that go to something comparable
        # NOTE(review): save_same_dest_edges is never read afterwards
        save_same_dest_edges = same_dest_edges.copy()
        # NOTE(review): this dedup pass removes nodes but sets neither
        # modified_graph nor found_more (the commented-out version below did) -
        # the function can return False after changing the graph; confirm
        while same_dest_edges:
            edge_set = same_dest_edges.pop(0)
            keep_node = edge_set[0].to_node
            # pull in any other sets that also involve keep_node
            other_edge_sets = [
                edges for edges in same_dest_edges
                if any(edge.to_node == keep_node for edge in edges)
            ]
            for other_edge_set in other_edge_sets:
                same_dest_edges.remove(other_edge_set)
            nodes_to_delete = set()
            for edge_set in [edge_set] + other_edge_sets:
                for edge in edge_set:
                    other_node = edge.to_node
                    if other_node == keep_node or other_node in nodes_to_delete:
                        continue
                    nodes_to_delete.add(other_node)
                    # NOTE(review): from_idx of out_edge is dropped here -
                    # presumably these nodes have a single output; verify
                    for out_edge in G.out_edges(other_node):
                        G.add_edge(
                            NNEdge(from_node=keep_node,
                                   to_node=out_edge.to_node,
                                   to_idx=out_edge.to_idx))
            LOG.info(
                f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
            )
            for node in nodes_to_delete:
                G.remove(node)
        # # all are multiple edges that go to something comparable
        # for edge_set in same_dest_edges:
        #     modified_graph = True
        #     found_more = True
        #     first = edge_set[0]
        #     first_node = first.to_node
        #     dup_nodes = []
        #     for other in edge_set[1::]:
        #         dest_node = other.to_node
        #         dup_nodes.append(dest_node.name)
        #         out_edges = G.out_edges(dest_node.name)
        #         G.remove(dest_node)
        #         for out_edge in out_edges:
        #             G.add_edge(NNEdge(from_node=first_node, to_node=out_edge.to_node,
        #                               from_idx=out_edge.from_idx, to_idx=out_edge.to_idx))
        #     LOG.info(
        #         f'removed duplicates {",".join(dup_nodes)} to {first_node.name}')
        for edge_set in same_dest_group_edges:
            modified_graph = True
            found_more = True
            # we will merge all the convolutions into one
            first = edge_set[0]
            first_node = first.to_node
            in_edges = G.indexed_in_edges(first_node.name)
            first_filter = first_node.filter
            weights_node = in_edges[1].from_node
            biases_node = in_edges[2].from_node
            dup_nodes = []
            num_convs = len(edge_set)
            out_shape = deepcopy(first_node.out_dims[0])
            out_shape.c *= num_convs
            # create a split after the first node splitting on channel axis
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                out_shape,
                out_shape.get_order_idx('c'),
                num_splits=num_convs)
            split1 = SplitParameters(
                G.unique_name(f'{first_node.name}_split'),
                act_slices=act_slices,
                out_shapes=out_shapes,
                axis=axis)
            out_num = 0
            # first node out edge goes to split
            out_edges = G.out_edges(first_node.name)
            for edge in out_edges:
                G.remove_edge(edge)
                G.add_edge(
                    NNEdge(from_node=split1,
                           from_idx=out_num,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))
            G.add_edge(NNEdge(from_node=first_node, to_node=split1))
            # first split output goes to original output
            for other in edge_set[1::]:
                out_num += 1
                node_other = other.to_node
                dup_nodes.append(node_other.name)
                in_edges = G.indexed_in_edges(node_other.name)
                weights_other = in_edges[1].from_node
                biases_other = in_edges[2].from_node
                # merge the weights and biases down output channel
                weights_node.value = np.concatenate(
                    (weights_node.value, weights_other.value),
                    axis=first_filter.get_order_idx('out_c'))
                weights_node.dims = Dim.unnamed(weights_node.value.shape)
                biases_node.value = np.concatenate(
                    (biases_node.value, biases_other.value))
                biases_node.dims = Dim.unnamed(biases_node.value.shape)
                first_filter.out_c += node_other.filter.out_c
                # wire edge from split
                out_edges = G.out_edges(node_other.name)
                G.remove(node_other)
                G.remove(weights_other)
                G.remove(biases_other)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
            LOG.info(
                f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
            )
        if not found_more:
            break
    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Merge consecutive split outputs that feed consecutive concat inputs.

    Phase 1 scans each split's output bundles (single-consumer only) and
    follows them down (through copy/no-op nodes) to a concat; runs of
    outputs that land on consecutive inputs of the same concat form a group.
    Phase 2 collapses each group into one split output, reindexing the
    remaining split outputs and concat inputs.

    NOTE(review): set_identity is accepted but self.set_identity is never
    called here, unlike the sibling matchers - confirm this is intentional.

    Returns True if any group was combined.
    """
    edge_groups = []
    for node in G.nodes(node_classes=SplitParameters):
        cur_group = None
        for out_edge_bundle in G.indexed_out_edges(node):
            if len(out_edge_bundle) == 1:
                out_edge = out_edge_bundle[0]
                concat_node_edges = search_down(
                    G, out_edge, ConcatParameters,
                    can_pass=(CopyParameters, NoOPParameters))
                if concat_node_edges:
                    if cur_group:
                        this_concat_edge = concat_node_edges[-1]
                        last_concat_edge = cur_group[-1][-1]
                        # extend the run only if it reaches the same concat
                        # at the next input index
                        if this_concat_edge.to_node == last_concat_edge.to_node and this_concat_edge.to_idx == last_concat_edge.to_idx + 1:
                            cur_group.append(concat_node_edges)
                            continue
                        if len(cur_group) > 1:
                            edge_groups.append(cur_group)
                    cur_group = [concat_node_edges]
                    continue
            # branchy output or no concat found - close any open run
            if cur_group:
                if len(cur_group) > 1:
                    edge_groups.append(cur_group)
                cur_group = None
        if cur_group:
            if len(cur_group) > 1:
                edge_groups.append(cur_group)
            cur_group = None
    # we leave the splits and concats after this since they will be cleared up by remove_noops
    for edge_group in edge_groups:
        split_node = edge_group[0][0].from_node
        concat_node = edge_group[0][-1].to_node
        # here from_idx/to_idx are the first/last split OUTPUT indices in the run
        from_idx = edge_group[0][0].from_idx
        to_idx = edge_group[-1][0].from_idx
        LOG.info(
            f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
        )
        # combine slices and shapes on edges in group
        new_slice, new_shape = reduce_slices(
            split_node.act_slices[from_idx:to_idx + 1],
            split_node.out_shapes[from_idx:to_idx + 1])
        split_node.act_slices = split_node.act_slices[:from_idx] + [
            new_slice
        ] + split_node.act_slices[to_idx + 1:]
        split_node.out_shapes = split_node.out_shapes[:from_idx] + [
            new_shape
        ] + split_node.out_shapes[to_idx + 1:]
        # remove all edges and intermediate nodes on all edge groups except the first
        for edge_list in edge_group[1:]:
            remove_edges(G, edge_list)
        out_edge_bundles = G.indexed_out_edges(split_node)
        # move edges beyond the edge group after the first index
        for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
            assert len(edge_list) == 1
            edge = edge_list[0]
            G.remove_edge(edge)
            G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
        # reindex the in edges in the concat
        # (from here on from_idx/to_idx are concat INPUT indices)
        from_idx = edge_group[0][-1].to_idx
        to_idx = edge_group[-1][-1].to_idx
        in_edges = G.indexed_in_edges(concat_node)
        for offset, in_edge in enumerate(in_edges[to_idx + 1:]):
            G.remove_edge(in_edge)
            G.add_edge(NNEdge.clone(in_edge, to_idx=from_idx + 1 + offset))
    return bool(edge_groups)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse Pad -> Add (-> activation) chains into PaddedAddFusionParameters.

    Same rewrite as the legacy `match` variant but using the unified QRec
    quantization API: one fusion-level record is built with copy_ktype and
    any cached quantization stats are invalidated.

    Returns True if the graph was modified.
    """
    has_modified_graph = False
    for pad_node in [
            params for params in G.nodes()
            if isinstance(params, PadParameters)
    ]:
        node_list = self.get_node_list(G, pad_node)
        # need at least pad + add to fuse
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # which add input the pad feeds (0 or 1)
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad, to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # fusion input order mirrors which side of the add is padded
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            # if there are quantization stats then clear them. They need to
            # be created again
            G.quantization.stats = None
            if qrecs:
                # fusion record: inputs from first contained record, outputs
                # from last, same ktype as the first
                prec = QRec.copy_ktype(qrecs[0],
                                       in_qs=qrecs[0].in_qs,
                                       out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # gather external in-edges in fusion input order
        if padded_input_idx == 0:
            in_edges = G.in_edges(node_list.pad.name) + \
                G.indexed_in_edges(node_list.add.name)[1::]
        else:
            in_edges = G.indexed_in_edges(
                node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        # NOTE(review): to_idx is reused from the original destination edge -
        # presumably this matches the fusion's input numbering; confirm for
        # the padded_input_idx == 1 case
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph