def reverse_matmul(G: GraphView, params):
    """Rewrite the matmul ``params`` in place so that it computes (B^T . A^T)^T.

    The first two in-edges of ``params`` are swapped: input 0 is reconnected
    to index 1 and input 1 to index 0. If the node has a quantization record,
    its first two input qtypes are swapped to follow the edges. A (1, 0)
    TransposeParameters node is then inserted in front of each of the two
    inputs, and another one after the node.

    Args:
        G: the graph holding ``params``; modified in place.
        params: the matmul node to reverse.

    Returns:
        ``(in_nodes, tout_params)``. ``in_nodes`` holds the two inserted input
        transposes in input-index order; ``tout_params`` is the inserted output
        transpose.
    """
    swapped_edges = G.indexed_in_edges(params.name)[0:2]
    for old_edge in swapped_edges:
        G.remove_edge(old_edge)
    # reverse edges: the first operand goes to index 1, the second to index 0
    for new_idx, old_edge in zip((1, 0), swapped_edges):
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=params,
                   from_idx=old_edge.from_idx,
                   to_idx=new_idx))

    node_id = NodeId(params)
    if G.quantization and node_id in G.quantization:
        qrec = G.quantization[node_id]
        # keep the input qtypes aligned with the swapped edges
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    in_nodes = []
    for input_idx in (0, 1):
        transpose_node = TransposeParameters(
            G.unique_name(f"{params.name}_tin{input_idx}"), transpose=(1, 0))
        in_nodes.append(transpose_node)
        G.insert_node_before(transpose_node, params, to_idx=input_idx,
                             edge_class=NNEdge)
    tout_params = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Eliminate strided slices whose slice shape equals their input shape.

    Such a slice is removed and its input reconnected when its output shape
    also equals the slice shape. Otherwise it is replaced by a
    ReshapeParameters node going from the slice shape to the output shape.
    Quantization records are moved to the reshape, or dropped for removed
    slices.

    Returns:
        True if at least one slice was removed or replaced.
    """
    changed = False
    slice_nodes = list(G.nodes(node_classes=StridedSliceParameters))
    for slice_node in slice_nodes:
        # only slices that keep the whole input are candidates
        if slice_node.slice_shape != tuple(slice_node.in_dims[0].shape):
            continue
        changed = True
        slice_id = NodeId(slice_node)
        if slice_node.slice_shape == slice_node.out_shape:
            LOG.info(
                f'removing strided slice {slice_node.name} that does nothing')
            G.remove_and_reconnect(slice_node, edge_class=NNEdge)
            if G.quantization and slice_id in G.quantization:
                del G.quantization[slice_id]
            continue
        # the slice only changes the shape, so a reshape does the same job
        reshape = ReshapeParameters(
            G.unique_name(f'{slice_node.name}_reshape'),
            old_shape=slice_node.slice_shape,
            shape=slice_node.out_shape)
        LOG.info(
            f'replacing strided slice {slice_node.name} with reshape {reshape.name}'
        )
        G.replace_node(slice_node, reshape)
        if G.quantization and slice_id in G.quantization:
            G.quantization[NodeId(reshape)] = G.quantization[slice_id]
            del G.quantization[slice_id]
    if set_identity:
        self.set_identity(G)
    return changed
def move_constant(cls, G: GraphView, params, in_qs):
    """Ensure a constant operand of the matmul ``params`` sits on input 1.

    If input 1 is already a ConstantInputParameters, it is returned with
    ``in_qs`` unchanged. If instead input 0 is constant, the first two
    in-edges are swapped and the identity A.B = (BT.AT)T is applied. This
    inserts (1, 0) transposes before both inputs and after the output, and
    logs a warning asking for adjust to be rerun. When a bias (input 2) is
    present, the move only happens if the bias size passes the length check
    below.

    Args:
        G: the graph holding ``params``; modified in place when the constant
            is moved.
        params: the matmul node.
        in_qs: input qtypes of ``params``. It is sliced and concatenated with
            a list, so it presumably must be a list.

    Returns:
        ``(constant_node, in_qs)``. ``constant_node`` is the constant now on
        input 1, or None if none is or could be put there. ``in_qs`` is
        ordered to match the resulting edges.
    """
    in_edges = G.indexed_in_edges(params.name)
    first_input = in_edges[0].from_node
    second_input = in_edges[1].from_node
    if isinstance(second_input, ConstantInputParameters):
        return second_input, in_qs
    if not isinstance(first_input, ConstantInputParameters):
        return None, in_qs

    if len(params.in_dims) > 2:
        # check if the bias has the correct length to move constant
        # NOTE(review): the original intent is "a length equal to the second
        # tensor's second dimension after transpose", but this compares
        # against dimension 1 of input 0 - confirm that is the right axis.
        bias_size = params.in_dims[2].size()
        first_shape = params.in_dims[0].shape
        if first_shape[1] != bias_size:
            return None, in_qs

    swapped_edges = in_edges[:2]
    for old_edge in swapped_edges:
        G.remove_edge(old_edge)
    # swap edges to move constant onto input 2
    for new_idx, old_edge in zip((1, 0), swapped_edges):
        G.add_edge(NNEdge(from_node=old_edge.from_node,
                          to_node=old_edge.to_node,
                          from_idx=old_edge.from_idx,
                          to_idx=new_idx))

    # use A.B = (BT.AT)T identity
    tin1 = TransposeParameters(G.unique_name(f'{params.name}_tin1'),
                               transpose=(1, 0))
    tin2 = TransposeParameters(G.unique_name(f'{params.name}_tin2'),
                               transpose=(1, 0))
    tout = TransposeParameters(G.unique_name(f'{params.name}_tout'),
                               transpose=(1, 0))
    G.insert_node_before(tin1, params)
    G.insert_node_before(tin2, params, to_idx=1)
    G.insert_node_after(params, tout)
    LOG.warning('transposes inserted on %s - rerun adjust', params.name)
    return first_input, [in_qs[1], in_qs[0]] + in_qs[2:]
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Insert a CopyParameters node in front of qualifying split nodes.

    A SplitParameters node qualifies when ``search_up_for_input(G, node)``
    is truthy. Presumably this means it is fed, possibly indirectly, by a
    graph input; confirm against that helper. The copy is placed on the
    split's first in-edge. In quantized graphs it takes the split's
    input-0 qtype.

    Returns:
        True if any copy node was inserted.
    """
    splits = [split for split in G.nodes(node_classes=SplitParameters)
              if search_up_for_input(G, split)]
    for split in splits:
        LOG.info("Insert copy on split input %s", split.name)
        copy_node = CopyParameters(G.unique_name(f'{split.name}_copy'))
        G.insert_node_at_edge(copy_node, G.in_edges(split.name)[0])
        if G.quantization:
            G.quantization.copy_qrec(split, 'in', 0, copy_node)
    if set_identity:
        self.set_identity(G)
    # every candidate gets a copy, so the graph changed iff there were any
    return bool(splits)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse ``relu6(x) * (1/6)`` into one HSigmoidActivationParameters node.

    The pattern is:
    - a ReluActivationParameters with ``upper_bound == 6``,
    - whose only consumer is a two-input MatrixMulParameters,
    - whose other input is a ConstantInputParameters used only by that
      multiply and equal to 1/6 according to ``check_equals``.

    The relu, the multiply and the constant are replaced by a single
    hsigmoid with ``offset=0``. In quantized graphs the new node takes the
    relu's input qtypes and the multiply's output qtypes. The records of
    all three replaced nodes are then dropped, so no stale entries remain.

    Returns:
        True if at least one fusion was performed.
    """
    something_changed = False
    relu6_nodes = [node for node in G.nodes(node_classes=ReluActivationParameters)
                   if node.upper_bound == 6]
    for relu_node in relu6_nodes:
        out_edges = G.out_edges(relu_node)
        if len(out_edges) != 1 or not isinstance(out_edges[0].to_node,
                                                 MatrixMulParameters):
            continue
        mul_node = out_edges[0].to_node
        in_edges = G.in_edges(mul_node)
        if len(in_edges) != 2:
            continue
        # the multiply's other operand must be an unshared 1/6 constant
        other_edge = (set(in_edges) - {out_edges[0]}).pop()
        constant_node = other_edge.from_node
        if len(G.out_edges(constant_node)) != 1:
            continue
        if (not isinstance(constant_node, ConstantInputParameters)
                or not check_equals(G, constant_node, 1.0/6.0)):
            continue
        something_changed = True
        activation = HSigmoidActivationParameters(
            G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)
        in_edges = G.in_edges(relu_node)
        out_edges = G.out_edges(mul_node)
        nodes_to_replace = [relu_node, mul_node, constant_node]
        LOG.info(f'fusing {", ".join(node.name for node in nodes_to_replace)} into HSIGMOID {activation.name}')
        G.remove_all(nodes_to_replace)
        for in_edge in in_edges:
            G.add_edge(NNEdge.clone(in_edge, to_node=activation, to_idx=0))
        for out_edge in out_edges:
            G.add_edge(NNEdge.clone(
                out_edge, from_node=activation, from_idx=0))
        if G.quantization:
            reluqrec = G.quantization[NodeId(relu_node)]
            mulqrec = G.quantization[NodeId(mul_node)]
            pqrec = QRec.copy_ktype(
                reluqrec, in_qs=reluqrec.in_qs, out_qs=mulqrec.out_qs)
            # drop the records of every removed node, not just the constant,
            # and tolerate nodes that were never quantized
            for removed_node in nodes_to_replace:
                removed_id = NodeId(removed_node)
                if removed_id in G.quantization:
                    del G.quantization[removed_id]
            G.quantization[NodeId(activation)] = pqrec
    if set_identity:
        self.set_identity(G)
    return something_changed
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Insert copy nodes on split/concat inputs fed by graph inputs or constants.

    Checks each in-edge of every SplitParameters and ConcatParameters node.
    When the edge's real producer, as resolved by ``find_real_in_edge``, is
    an InputParameters or ConstantInputParameters, a CopyParameters node is
    inserted on that edge. In quantized graphs the copy takes the qtype of
    the consumer input it is inserted in front of.

    Returns:
        True if any copy node was inserted.
    """
    candidates = list(G.nodes(node_classes=(SplitParameters, ConcatParameters)))
    need_a_copy_edges = []
    for node in candidates:
        for idx, edge in enumerate(G.indexed_in_edges(node.name)):
            real_from_node, _ = find_real_in_edge(G, edge)
            if isinstance(real_from_node, (InputParameters, ConstantInputParameters)):
                need_a_copy_edges.append((edge, idx))
    for edge, idx in need_a_copy_edges:
        consumer = edge.to_node
        LOG.info(
            "Insert copy on split input %s", consumer.name)
        cnode = CopyParameters(G.unique_name(f'{consumer.name}_copy'))
        G.insert_node_at_edge(cnode, edge)
        if G.quantization:
            # use the qtype of the input this copy actually feeds (a concat
            # has several inputs; they are not all described by input 0)
            G.quantization.copy_qrec(consumer, 'in', idx, cnode)
    if set_identity:
        self.set_identity(G)
    return bool(need_a_copy_edges)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Collapse groups of GlobalPoolingParameters nodes into the first of each group.

    Each group comes from ``self.reductions(G, node)``. Groups with more
    than one node are combined as follows:
    - ``reduce_reduction`` folds the group into its reduction axes, the new
      shape and the keep-dims flag;
    - the first node takes over those axes and flag, plus the last node's
      output qtypes when quantized;
    - the other nodes are removed;
    - a reshape is appended when keep-dims is set but the rank differs from
      the input rank;
    - the last node's consumers are reconnected to the result.

    Returns:
        True if any group was collapsed.
    """
    nodes = list(G.nodes(node_classes=GlobalPoolingParameters))
    modified_graph = False
    while nodes:
        node = nodes.pop()
        node_group = self.reductions(G, node)
        if len(node_group) <= 1:
            continue
        modified_graph = True
        reduction_axes, new_shape, has_keepdims, _ = reduce(
            reduce_reduction, node_group, None)
        # the first node of the group is kept and becomes the combined reduction
        new_node = node_group[0]
        new_node.axis = sorted(list(reduction_axes))
        new_node.keep_dims = has_keepdims
        # the combined node's consumers are the last node's consumers
        out_edges = G.out_edges(node_group[-1].name)
        if G.quantization:
            last_qrec = G.quantization[NodeId(node_group[-1])]
            G.quantization[NodeId(new_node)].out_qs = last_qrec.out_qs
        # NOTE(review): nodes removed here may still be in `nodes` and get
        # popped later; confirm self.reductions copes with nodes that are no
        # longer in G.
        for node in node_group[1::]:
            G.remove(node.name)
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]
        if has_keepdims and len(new_shape) != len(
                new_node.in_dims[0].shape):
            rparams = ReshapeParameters(
                G.unique_name(f'{new_node.name}_reshape'),
                shape=Dim.unnamed(new_shape))
            if G.quantization:
                # NOTE(review): a QRec is passed here, while other callers of
                # copy_qrec pass a node; confirm the API accepts both.
                G.quantization.copy_qrec(last_qrec, 'out', 0, rparams)
            G.add_edge(NNEdge(new_node, rparams))
            new_node = rparams
        for edge in out_edges:
            G.add_edge(NNEdge(new_node, edge.to_node, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold constant elementwise add/mul operations into preceding matmuls.

    For each MatMulOpParameters node, its single consumer is examined
    repeatedly. A consumer is folded into the matmul when it is a
    MatrixAddParameters or MatrixMulParameters whose other input is a
    ConstantInputParameters holding as many values as one of the two output
    dimensions.

    * add: the constant is added to the existing bias (input 2), provided
      the sizes match. Otherwise it becomes a new bias. When its length
      equals output dimension 1, the matmul is first rewritten with
      ``reverse_matmul``, which inserts transposes.
    * mul: the constant multiplies the matmul's constant operand (input 1
      if constant, else input 0) and the bias, if any.

    The folded op node is removed, along with the constant when the op was
    its only consumer. The op's consumers are reconnected to whatever fed
    the op. Quantized graphs are re-quantized starting from the matmul.

    Raises:
        ValueError: if a matmul's output is not 2-dimensional.

    Returns:
        True if at least one operation was folded.
    """
    has_modified_graph = False
    has_transposed = False
    # materialize the node list: the graph is modified inside the loop
    for params in list(G.nodes(node_classes=MatMulOpParameters)):
        while True:
            out_edges = G.out_edges(params.name)
            # can't fuse if there is a branch (or no consumer at all)
            if len(out_edges) != 1:
                break
            out_edge = out_edges[0]
            op_node = out_edge.to_node
            # must be a valid matrix op
            if not isinstance(op_node,
                              (MatrixAddParameters, MatrixMulParameters)):
                break
            # other edge to the op must be a constant
            other_idx = 1 if out_edge.to_idx == 0 else 0
            other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
            if not isinstance(other_in_edge.from_node,
                              ConstantInputParameters):
                break
            const_node = other_in_edge.from_node
            # only delete the constant if the folded op is its sole consumer
            remove_constant = len(G.out_edges(const_node.name)) == 1
            flat_value = const_node.dqvalue.flatten()
            out_shape = params.out_dims[0].shape
            if len(out_shape) != 2:
                raise ValueError(
                    f'strange outputs shape of {out_shape} for matmul {params.name}'
                )
            if len(flat_value) != out_shape[0] and len(
                    flat_value) != out_shape[1]:
                LOG.info(
                    "can't fuse %s into %s - value shape is not correct for bias",
                    const_node.name, params.name)
                break
            has_bias = len(params.in_dims) == 3
            if isinstance(op_node, MatrixAddParameters):
                if has_bias:
                    # the constant must supply exactly one value per bias element
                    if len(flat_value) != params.in_dims[2].size():
                        LOG.info(
                            "can't fuse %s into %s - bias shape is not the same",
                            const_node.name, params.name)
                        break
                    bias_node = G.indexed_in_edges(
                        params.name)[2].from_node
                    LOG.info(
                        "folding additive bias from %s into existing bias on %s",
                        op_node.name, params.name)
                    bias_node.value = bias_node.dqvalue + flat_value
                else:
                    if len(flat_value) == out_shape[1]:
                        # matmul needs to be transposed to fuse this
                        reverse_matmul(G, params)
                        has_transposed = True
                    bias_node = ConstantInputParameters(
                        G.unique_name(f'{params.name}_bias'),
                        value=flat_value,
                        dims=Dim.unnamed(flat_value.shape))
                    G.add_edge(
                        NNEdge(from_node=bias_node, to_node=params, to_idx=2))
                    # extend the inward transpose
                    if params.transpose_in:
                        params.transpose_in = params.transpose_in + [None]
                    LOG.info(
                        "folding additive bias from %s into new bias on %s",
                        op_node.name, params.name)
            else:
                params_in = G.indexed_in_edges(params.name)
                # only the two matmul operands can absorb the scale; the bias
                # (input 2) is always constant and must not satisfy this test
                consts = [
                    isinstance(edge.from_node, ConstantInputParameters)
                    for edge in params_in[:2]
                ]
                if not any(consts):
                    break
                mult_const_node = params_in[1].from_node if consts[
                    1] else params_in[0].from_node
                # NOTE(review): the scale is applied with plain numpy
                # broadcasting of const_node.dqvalue; confirm this is correct
                # for per-row (length out_shape[0]) as well as per-column
                # constants.
                mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                if has_bias:
                    bias_node = params_in[2].from_node
                    bias_node.value = bias_node.dqvalue * const_node.dqvalue
                LOG.info(
                    "folding multaplicative bias from %s into new bias on %s",
                    op_node.name, params.name)
            # reconnect from the node that really drives the op: after
            # reverse_matmul this is the inserted output transpose, not params
            feeding_edge = G.indexed_in_edges(op_node.name)[out_edge.to_idx]
            out_edges = G.out_edges(op_node.name)
            G.remove(op_node)
            if remove_constant:
                G.remove(const_node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=feeding_edge.from_node,
                           from_idx=feeding_edge.from_idx,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))
            has_modified_graph = True
            G.add_dimensions()
            if G.quantization:
                quantizer = UnifiedQuantizer.from_quantized_graph(G)
                # quantize() returns the quantization set; store it on the graph
                G.quantization = quantizer.quantize(G, start_nodes=[params])
                RemoveUnnecessaryQuantizeOperators().match(G)
    if has_transposed:
        G.adjust_order()
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each node set found by ``self.find_sets`` into an ExpressionFusionParameters.

    For each node set:
    - the nodes and their internal edges are gathered into a GraphView
      fragment;
    - the fragment is replaced in G by a single ``expr_<n>`` node.

    For quantized graphs, additionally:
    - the new node gets a 'scaled' QRec built from the qtypes at the
      fragment boundary;
    - min/max stats are recorded for the expression's input and output
      symbols;
    - QuantizeParameters nodes directly on the expression's outputs are
      removed;
    - all new expressions are quantized together at the end.

    Returns:
        True if at least one node set was fused.
    """
    has_modified_graph = False
    to_quantize = []
    node_sets = self.find_sets(G)
    for node_set in node_sets:
        # fresh SymbolStats control for this expression; presumably this is
        # what Symbol.CURRENT_CONTROL.stats reads back below - confirm
        Symbol.set_default_control(SymbolStats())
        has_modified_graph = True
        # in_edges / out_edges appear to map (from_node, from_idx) -> set of
        # edges crossing the fragment boundary - confirm against group_edges
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        for node in node_set:
            frag.add_node(node)
        for edge in internal_edges:
            frag.add_edge(edge)
        # for each external input: the (node, input index) pairs inside the fragment it feeds
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        LOG.debug(
            "inputs coming from: %s",
            ",".join(f"{from_node.__repr__()}:{from_idx}"
                     for from_node, from_idx in in_edges))
        LOG.info("fusing nodes: %s into expr_%s",
                 ",".join(node.__repr__() for node in node_set), self._expr_num)
        expr = ExpressionFusionParameters(
            G.unique_name(f"expr_{self._expr_num}"),
            subgraph=frag,
            qrecs=G.quantization,
            input_mapping=in_mapping,
            output_mapping=out_mapping,
            in_dims=in_dims,
            out_dims=out_dims,
            constant_inputs=constant_inputs)

        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        # NOTE(review): set.union(*...) needs at least one group and set-typed
        # values; it raises TypeError otherwise
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping,
            edge_class=NNEdge)
        if G.quantization:
            qrecs = G.quantization
            # input qtype taken from the first fragment node each input feeds
            in_qs = [
                qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                for in_map in in_mapping
            ]
            out_qs = [
                qrecs[NodeId(node)].out_qs[idx]
                for node, idx in out_mapping
            ]
            stats = Symbol.CURRENT_CONTROL.stats
            func_col = expr.func_col
            # record the boundary ranges against the expression's symbols
            for idx, qtype in enumerate(in_qs):
                symbol = func_col.variables[func_col.input_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            for idx, qtype in enumerate(out_qs):
                symbol = func_col.variables[func_col.output_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                out_qs=out_qs,
                                                expression=stats,
                                                ktype='scaled')
            # delete any quantize parameters on outputs to allow the quantizer
            # to fuse them into the expression
            out_edges = G.out_edges(expr.name)
            for edge in out_edges:
                if isinstance(edge.to_node, QuantizeParameters):
                    G.remove_and_reconnect(edge.to_node)
                    if NodeId(edge.to_node) in G.quantization:
                        del G.quantization[NodeId(edge.to_node)]
            to_quantize.append(expr)
        self._expr_num += 1
    # only quantized graphs add to to_quantize
    if to_quantize:
        quantizer = UnifiedQuantizer.from_quantized_graph(G)
        G.quantization = quantizer.quantize(G, start_nodes=to_quantize)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Merge duplicated operations that consume the same output.

    Only unquantized graphs are handled. For quantized graphs a warning is
    logged and False is returned.

    Out-edges are grouped by their source (node, output index). Within a
    group, only ComparableParameters consumers are considered:
    - consumers for which ``is_same_operation_as`` holds are collapsed into
      the first one, which takes over their consumers;
    - consumers for which ``can_be_grouped_with`` holds are merged into the
      first one. Weights (input 1) and biases (input 2) are concatenated
      along the output-channel axis, and a SplitParameters on the channel
      axis hands each original consumer its slice.

    The scan repeats until a pass finds nothing more to merge.

    Returns:
        True if the graph was modified.
    """
    if G.quantization:
        LOG.warning(
            'match_duplicate_operations does not handle quantized graphs')
        return False

    def same_source_edge_fn(x):
        return f"{x.from_node.__hash__()}##{x.from_idx}"

    modified_graph = False
    while True:
        found_more = False
        same_source_edges = [
            list(edge_list) for _, edge_list in groupby(
                sorted(G.edges(), key=same_source_edge_fn),
                same_source_edge_fn)
        ]
        # all have the same origin
        same_source_edges = [
            elem for elem in same_source_edges if len(elem) > 1
        ]
        same_dest_edges = []
        same_dest_group_edges = []

        for same_source_edge in same_source_edges:
            same_source_edge = [
                edge for edge in same_source_edge
                if isinstance(edge.to_node, ComparableParameters)
            ]
            while same_source_edge:
                first = same_source_edge.pop(0)

                others = list(
                    filter(
                        partial(
                            lambda x, y: x.to_node != y.to_node and y.
                            to_node.is_same_operation_as(G, x.to_node),
                            first), same_source_edge))
                if others:
                    same_dest_edges.append(tuple([first] + others))
                    for other in others:
                        same_source_edge.remove(other)
                    continue

                other_groups = list(
                    filter(
                        partial(
                            lambda x, y: x.to_node != y.to_node and y.
                            to_node.can_be_grouped_with(x.to_node), first),
                        same_source_edge))
                if other_groups:
                    same_dest_group_edges.append(
                        tuple([first] + other_groups))
                    for other in other_groups:
                        same_source_edge.remove(other)

        # all are multiple edges that go to something comparable
        while same_dest_edges:
            edge_set = same_dest_edges.pop(0)
            keep_node = edge_set[0].to_node
            # merge every other duplicate set that also involves keep_node
            other_edge_sets = [
                edges for edges in same_dest_edges
                if any(edge.to_node == keep_node for edge in edges)
            ]
            for other_edge_set in other_edge_sets:
                same_dest_edges.remove(other_edge_set)

            nodes_to_delete = set()
            for edge_set in [edge_set] + other_edge_sets:
                for edge in edge_set:
                    other_node = edge.to_node
                    if other_node == keep_node or other_node in nodes_to_delete:
                        continue
                    nodes_to_delete.add(other_node)
                    # keep_node takes over the duplicate's consumers, on the
                    # same output index
                    for out_edge in G.out_edges(other_node):
                        G.add_edge(
                            NNEdge(from_node=keep_node,
                                   from_idx=out_edge.from_idx,
                                   to_node=out_edge.to_node,
                                   to_idx=out_edge.to_idx))

            LOG.info(
                f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
            )
            for node in nodes_to_delete:
                G.remove(node)
            if nodes_to_delete:
                # report the change and rescan: merging may expose new duplicates
                modified_graph = True
                found_more = True

        for edge_set in same_dest_group_edges:
            modified_graph = True
            found_more = True
            # we will merge all the convolutions into one
            first = edge_set[0]
            first_node = first.to_node
            in_edges = G.indexed_in_edges(first_node.name)
            first_filter = first_node.filter
            weights_node = in_edges[1].from_node
            biases_node = in_edges[2].from_node
            dup_nodes = []
            num_convs = len(edge_set)
            out_shape = deepcopy(first_node.out_dims[0])
            out_shape.c *= num_convs
            # create a split after the first node splitting on channel axis
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                out_shape, out_shape.get_order_idx('c'), num_splits=num_convs)
            split1 = SplitParameters(
                G.unique_name(f'{first_node.name}_split'),
                act_slices=act_slices,
                out_shapes=out_shapes,
                axis=axis)
            out_num = 0
            # first node out edge goes to split
            out_edges = G.out_edges(first_node.name)
            for edge in out_edges:
                G.remove_edge(edge)
                G.add_edge(
                    NNEdge(from_node=split1,
                           from_idx=out_num,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))
            G.add_edge(NNEdge(from_node=first_node, to_node=split1))
            # first split output goes to original output
            for other in edge_set[1::]:
                out_num += 1
                node_other = other.to_node
                dup_nodes.append(node_other.name)
                in_edges = G.indexed_in_edges(node_other.name)
                weights_other = in_edges[1].from_node
                biases_other = in_edges[2].from_node
                # merge the weights and biases down output channel
                weights_node.value = np.concatenate(
                    (weights_node.value, weights_other.value),
                    axis=first_filter.get_order_idx('out_c'))
                weights_node.dims = Dim.unnamed(weights_node.value.shape)
                biases_node.value = np.concatenate(
                    (biases_node.value, biases_other.value))
                biases_node.dims = Dim.unnamed(biases_node.value.shape)
                first_filter.out_c += node_other.filter.out_c
                # wire edge from split
                out_edges = G.out_edges(node_other.name)
                G.remove(node_other)
                G.remove(weights_other)
                G.remove(biases_other)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
            LOG.info(
                f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
            )
        if not found_more:
            break

    if set_identity:
        self.set_identity(G)

    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Replace groups of strided slices that read the same output with one SplitParameters.

    Slices are grouped by the (node, output index) they read. A lone slice
    is handed to ``self.slice_to_split``.

    A group is converted only if every step is 1 and ``combine_slices``
    accepts it. The conversion:
    - moves the axes on which the slices differ to the front with a
      transpose, when needed;
    - collapses those axes into one with a reshape, when there are several;
    - splits along axis 0 with sizes from ``slices_to_sizes``;
    - feeds each former slice consumer from its split output, through a
      reshape and/or the reverse transpose.

    Quantized graphs are re-quantized after each conversion.

    Returns:
        True if any group was converted. The lone-slice path does not set
        this flag.
    """
    has_modified_graph = False
    # group slices by the (from_node, from_idx) they read from
    slices_by_origin = {}
    for slice_node in [
            node for node in G.nodes()
            if isinstance(node, StridedSliceParameters)
    ]:
        in_edge = G.in_edges(slice_node.name)[0]
        group = slices_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), [])
        group.append(slice_node)
    # note: in_edge here is the (from_node, from_idx) key, not an edge object
    for in_edge, slice_nodes in slices_by_origin.items():
        # slices[axis] holds every node's (start, stop, step) for that axis
        slices = list(zip(*[node.act_slice for node in slice_nodes]))
        if len(slice_nodes) == 1:
            self.slice_to_split(G, slice_nodes, slices)
            continue

        # strides must be one
        if any(sl[2] != 1 for sl_axis in slices for sl in sl_axis):
            continue

        # axes on which the slices disagree
        diff_axes = list([
            idx for idx, elems in enumerate(slices)
            if not all(elems[0] == elem for elem in elems[1::])
        ])
        not_diff_axes = [
            idx for idx in range(len(slices)) if idx not in diff_axes
        ]
        diff_slices = [
            sl for idx, sl in enumerate(slices) if idx in diff_axes
        ]
        axis_lengths = in_edge[0].out_dims[in_edge[1]].shape
        # move the differing axes to the front when an identical axis precedes one
        if not_diff_axes and min(not_diff_axes) < max(diff_axes):
            transpose_from = tuple(range(len(slices)))
            transpose_to = tuple(diff_axes + not_diff_axes)
            axis_lengths = [axis_lengths[idx] for idx in transpose_to]
        else:
            transpose_from = transpose_to = None

        diff_axis_lengths = axis_lengths[0:len(diff_axes):]

        diff_slices = combine_slices(diff_axis_lengths, diff_slices,
                                     slice_nodes)
        if diff_slices is None:
            continue

        if len(diff_axes) > 1:
            # flatten the leading differing axes into a single axis
            reshape_from = axis_lengths
            reshape_to = [np.prod(diff_axis_lengths)] + \
                axis_lengths[len(diff_axes)::]
        else:
            reshape_from = None
            reshape_to = slice_nodes[0].in_dims[0].shape
            if transpose_from:
                reshape_to = [reshape_to[idx] for idx in transpose_to]

        sizes, shapes, sorted_nodes = slices_to_sizes(
            diff_slices, axis_lengths[len(diff_axes)::])

        name_prefix = sorted_nodes[0].name
        # from here on in_edge is a real graph edge into the first sorted slice
        in_edge = G.in_edges(sorted_nodes[0].name)[0]
        in_node = in_edge.from_node
        in_idx = in_edge.from_idx

        if transpose_from:
            params = TransposeParameters(G.unique_name(name_prefix + '_tin'),
                                         transpose=transpose_to)
            G.add_edge(
                NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
            in_node = params
            in_idx = 0

        if reshape_from:
            params = ReshapeParameters(G.unique_name(name_prefix + '_reshape'),
                                       old_shape=Dim.unnamed(reshape_from),
                                       shape=Dim.unnamed(reshape_to))
            G.add_edge(
                NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
            in_node = params
            in_idx = 0

        act_slices, out_shapes, axis = SplitParameters.get_splits(
            reshape_to, 0, splits=sizes)
        split_node = SplitParameters(G.unique_name(name_prefix + '_split'),
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)

        G.add_edge(
            NNEdge(from_node=in_node, from_idx=in_idx, to_node=split_node))

        sub_names = []
        for idx, node in enumerate(sorted_nodes):
            sub_names.append(node.name)
            out_edges = G.out_edges(node.name)
            G.remove(node)
            # each consumer gets its own reshape/transpose chain from split output idx
            for out_edge in out_edges:
                params = split_node
                out_idx = idx
                if reshape_from:
                    from_node = params
                    params = ReshapeParameters(
                        G.unique_name(name_prefix + f'_reshape{idx}'),
                        shape=Dim.unnamed(shapes[idx]))
                    G.add_edge(
                        NNEdge(from_node=from_node,
                               to_node=params,
                               from_idx=out_idx))
                    out_idx = 0
                if transpose_from:
                    from_node = params
                    params = TransposeParameters(
                        G.unique_name(name_prefix + f'_tout{idx}'),
                        transpose=reverse_transpose(transpose_to))
                    G.add_edge(
                        NNEdge(from_node=from_node,
                               to_node=params,
                               from_idx=out_idx))
                    out_idx = 0

                G.add_edge(
                    NNEdge(from_node=params,
                           to_node=out_edge.to_node,
                           from_idx=out_idx,
                           to_idx=out_edge.to_idx))
        if G.quantization:
            G.add_dimensions()
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()
            RemoveUnnecessaryQuantizeOperators().match(G)

        LOG.info(
            f'replaced slice nodes {",".join(sub_names)} with split node {split_node.name}'
        )

        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Flatten trees of axis-0 concats into a single concat.

    Each ConcatParameters on axis 0 is treated as a root, and
    ``find_concats_up`` returns the subgraph of concats feeding it. When that
    subgraph contains more than one concat:
    - every subgraph node except the root and DummyInput placeholders is
      removed;
    - every subgraph input is reconnected directly to the root, through a
      flattening reshape when the input has more than one dimension;
    - a reshape back to the root's original output shape is inserted after
      it.

    For quantized graphs, records of removed nodes are dropped. The root's
    in_qs are rebuilt from the producers' out_qs only when every producer
    has a record.

    Returns:
        True if any concats were combined.
    """
    modified_graph = False
    concats = set(G.nodes(node_classes=ConcatParameters))
    while concats:
        concat = concats.pop()
        if concat.axis != 0:
            continue
        subgraph = find_concats_up(G, concat)
        found = set(subgraph.nodes(node_classes=ConcatParameters))
        if len(found) <= 1:
            continue
        LOG.info(
            f"Combining concats {','.join([node.name for node in found])}")
        modified_graph = True
        # concats absorbed into this root need no separate visit
        concats -= found

        in_edges = [inp.edge for inp in subgraph.inputs()]
        # capture the input dims before the edges are removed
        in_dims = [
            edge.from_node.out_dims[edge.from_idx] for edge in in_edges
        ]
        nodes_to_remove = [
            node for node in subgraph.nodes()
            if node != concat and not isinstance(node, DummyInput)
        ]
        for edge in in_edges:
            G.remove_edge(edge)
        for node in nodes_to_remove:
            if node.name in G:
                G.remove(node)
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

        # remove_internal_graph(G, subgraph)
        out_dim = concat.out_dims[0]
        # in_qs becomes None as soon as any producer lacks a qrec (or the graph is unquantized)
        in_qs = []
        for idx, edge in enumerate(in_edges):
            from_node = edge.from_node
            from_idx = edge.from_idx
            if len(in_dims[idx]) > 1:
                # flatten multi-dimensional inputs so they concat on axis 0
                reshape = ReshapeParameters(
                    G.unique_name(f'{concat.name}_flat{idx}'),
                    old_shape=in_dims[idx],
                    shape=Dim.unnamed([in_dims[idx].size()]))
                G.add_edge(
                    NNEdge(from_node=from_node,
                           from_idx=from_idx,
                           to_node=reshape))
                from_node = reshape
                from_idx = 0
            G.add_edge(
                NNEdge(from_node=from_node,
                       from_idx=from_idx,
                       to_node=concat,
                       to_idx=idx))
            if in_qs is not None and G.quantization:
                nid = NodeId(edge.from_node)
                if nid in G.quantization:
                    qrec = G.quantization[nid]
                    in_qs.append(qrec.out_qs[edge.from_idx])
                else:
                    in_qs = None
            else:
                in_qs = None
        if in_qs is not None and G.quantization:
            nid = NodeId(concat)
            if nid in G.quantization:
                G.quantization[nid].in_qs = in_qs
        # restore the root concat's original output shape
        reshape = ReshapeParameters(G.unique_name(f'{concat.name}_expand'),
                                    old_shape=Dim.unnamed([out_dim.size()]),
                                    shape=out_dim)
        G.insert_node_after(concat, reshape, edge_class=NNEdge)

    if set_identity:
        self.set_identity(G)

    return modified_graph