def reverse_matmul(G: GraphView, params):
    """Swap the two operands of ``params`` and transpose its inputs and output.

    The edges feeding inputs 0 and 1 of ``params`` are exchanged, the two
    matching input quantization entries are exchanged when the graph holds a
    record for the node, and a ``(1, 0)`` transpose is inserted in front of
    each of the two inputs and behind the node.

    Args:
        G: Graph containing ``params``; it is modified in place.
        params: Node whose first two inputs are reversed.

    Returns:
        Tuple ``(in_nodes, tout_params)``: the list of the two transposes
        inserted before inputs 0 and 1, and the transpose inserted after
        the node.
    """
    operand_edges = G.indexed_in_edges(params.name)[0:2]
    for old_edge in operand_edges:
        G.remove_edge(old_edge)

    # re-attach the operands crossed over: former input 0 -> 1, former 1 -> 0
    for swapped_idx, old_edge in zip((1, 0), operand_edges):
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=params,
                   from_idx=old_edge.from_idx,
                   to_idx=swapped_idx))

    node_id = NodeId(params)
    if G.quantization and node_id in G.quantization:
        # keep the input quantization entries lined up with the swapped edges
        qrec = G.quantization[node_id]
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    # one transpose in front of each operand ...
    in_nodes = []
    for idx in (0, 1):
        transpose_in = TransposeParameters(
            G.unique_name(f"{params.name}_tin{idx}"), transpose=(1, 0))
        in_nodes.append(transpose_in)
        G.insert_node_before(transpose_in, params, to_idx=idx,
                             edge_class=NNEdge)

    # ... and one behind the result
    tout_params = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
def match(self, G: GraphView, set_identity: bool = True):
    """Give every consumer of a shared constant its own constant node.

    For each ``ConstantInputParameters`` node with more than one out edge,
    the first edge is left alone and every further edge is replaced by an
    edge from a new constant named ``<name>_<n>`` (``n`` counting from 1)
    that carries a copy of the original value.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning.

    Returns:
        True if at least one constant was duplicated.
    """
    has_modified = False
    for node in G.nodes(node_classes=ConstantInputParameters):
        consumers = G.out_edges(node.name)
        if len(consumers) <= 1:
            continue
        has_modified = True
        LOG.info(
            'node %s has more than one out edge and will be duplicated',
            node.name)
        # the first consumer keeps the original constant
        for copy_no, extra_edge in enumerate(consumers[1:], start=1):
            duplicate = ConstantInputParameters(
                f'{node.name}_{copy_no}',
                dims=Dim.unnamed(node.dims.shape),
                value=node.value.copy())
            G.remove_edge(extra_edge)
            G.add_edge(
                NNEdge(from_node=duplicate,
                       to_node=extra_edge.to_node,
                       to_idx=extra_edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified
def match(self, G: GraphView, set_identity: bool = True):
    """Insert a copy node on split outputs reported by ``find_split_concat``.

    For every ``SplitParameters`` node, ``self.find_split_concat`` is asked
    for a result that is iterated per split output index; each non-empty
    entry is a bundle of edge sets. The first edge of every edge set in a
    bundle is re-sourced from a new ``CopyParameters`` node, which is itself
    fed from the corresponding split output.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning, as the other matchers do.

    Returns:
        True if at least one copy node was inserted.
    """
    split_nodes = [
        node for node in G.nodes() if isinstance(node, SplitParameters)
    ]
    has_modified_graph = False
    for node in split_nodes:
        # traverse reshapes or transposes that do nothing - check gen
        # find edges connected to concats
        res = self.find_split_concat(G, node)
        if res is None:
            continue
        # TODO(martin) - group edges that have adjacent inputs and outputs
        qrec = G.quantization[NodeId(node)] if G.quantization else None
        for idx, bundle in enumerate(res):
            if not bundle:
                continue
            has_modified_graph = True
            copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
            for edge_set in bundle:
                first_edge = edge_set[0]
                G.remove_edge(first_edge)
                G.add_edge(
                    NNEdge(copy_node,
                           first_edge.to_node,
                           to_idx=first_edge.to_idx))
            G.add_edge(NNEdge(node, copy_node, from_idx=idx))
            if G.quantization:
                # in_qs/out_qs hold one entry per edge, so the single
                # input and single output of the copy are each wrapped in
                # a list rather than passed as a bare element
                G.quantization[NodeId(copy_node)] = qrec.__class__(
                    in_qs=[deepcopy(qrec.out_qs[idx])],
                    out_qs=[deepcopy(qrec.out_qs[idx])])

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove every node for which ``self.node_does_nothing`` is true.

    The node's first in edge is removed and each of its out edges is
    replaced by an edge that runs straight from that in edge's source to
    the out edge's destination; the node itself is then removed.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning.
        **kwargs: Accepted and not used by this method.

    Returns:
        True if at least one node was removed.
    """
    # decide up-front which nodes go, before the graph starts changing
    noop_nodes = [
        candidate for candidate in G.nodes()
        if self.node_does_nothing(G, candidate)
    ]
    for noop in noop_nodes:
        feed = G.in_edges(noop.name)[0]
        G.remove_edge(feed)
        for consumer in G.out_edges(noop.name):
            G.remove_edge(consumer)
            # bridge the producer directly to the consumer
            G.add_edge(
                NNEdge(feed.from_node,
                       consumer.to_node,
                       from_idx=feed.from_idx,
                       to_idx=consumer.to_idx))
        LOG.info(f'removing {noop.name} that does nothing')
        G.remove(noop)

    if set_identity:
        self.set_identity(G)
    return bool(noop_nodes)
def move_constant(cls, G: GraphView, params, in_qs):
    """Locate a constant operand of ``params``, moving it to input 1 if needed.

    If input 1 is already a ``ConstantInputParameters`` node nothing is
    changed. If only input 0 is constant, the two operand edges are
    exchanged and ``(1, 0)`` transposes are placed before both inputs and
    after the node, using the identity A.B = (BT.AT)T.

    Args:
        G: Graph containing ``params``; modified in place when the operands
            are exchanged.
        params: Node whose first two inputs are inspected.
        in_qs: Sequence of input quantization entries for ``params``.

    Returns:
        Tuple ``(constant_node, in_qs)``. ``constant_node`` is None when
        neither operand is constant, or when a third input exists whose
        size differs from dimension 1 of the first input's shape. When the
        operands were exchanged the returned ``in_qs`` has its first two
        entries exchanged as well.
    """
    in_edges = G.indexed_in_edges(params.name)
    first_src = in_edges[0].from_node
    second_src = in_edges[1].from_node

    if isinstance(second_src, ConstantInputParameters):
        # already where we want it
        return second_src, in_qs
    if not isinstance(first_src, ConstantInputParameters):
        return None, in_qs

    if len(params.in_dims) > 2:
        # a third input is present: only move the constant when its size
        # matches dimension 1 of the first input
        bias_size = params.in_dims[2].size()
        first_shape = params.in_dims[0].shape
        if first_shape[1] != bias_size:
            return None, in_qs

    operand_edges = in_edges[:2]
    for old_edge in operand_edges:
        G.remove_edge(old_edge)
    # swap edges to move constant onto input 2
    for swapped_idx, old_edge in zip((1, 0), operand_edges):
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=old_edge.to_node,
                   from_idx=old_edge.from_idx,
                   to_idx=swapped_idx))

    # use A.B = (BT.AT)T identity
    tin1, tin2, tout = [
        TransposeParameters(G.unique_name(f'{params.name}_{suffix}'),
                            transpose=(1, 0))
        for suffix in ('tin1', 'tin2', 'tout')
    ]
    G.insert_node_before(tin1, params)
    G.insert_node_before(tin2, params, to_idx=1)
    G.insert_node_after(params, tout)
    LOG.warning('transposes inserted on %s - rerun adjust', params.name)
    return first_src, [in_qs[1], in_qs[0]] + in_qs[2:]
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove duplicated operations and merge groupable ones.

    Edges are grouped by source node and output index. Among the
    ``ComparableParameters`` destinations of each group:

    * nodes for which ``is_same_operation_as`` holds are collapsed onto the
      first of them, whose output is wired to the consumers of the others;
    * otherwise nodes for which ``can_be_grouped_with`` holds are merged
      into the first one by concatenating weights and biases along the
      output channel, and a split node hands each original consumer its
      own slice.

    The scan is repeated until a pass changes nothing. Quantized graphs are
    not handled: a warning is logged and False is returned.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning (not on the quantized early return).
        **kwargs: Accepted and not used by this method.

    Returns:
        True if any node was removed or merged.
    """
    if G.quantization:
        LOG.warning(
            'match_duplicate_operations does not handle quantized graphs')
        return False

    def same_source_edge_fn(x):
        # key identifying the (source node, output index) an edge leaves from
        return f"{x.from_node.__hash__()}##{x.from_idx}"

    modified_graph = False
    while True:
        found_more = False
        # groupby only merges adjacent items, hence the sort on the same key
        same_source_edges = [
            list(edge_list) for _, edge_list in groupby(
                sorted(G.edges(), key=same_source_edge_fn),
                same_source_edge_fn)
        ]
        # all have the same origin
        same_source_edges = [
            elem for elem in same_source_edges if len(elem) > 1
        ]
        same_dest_edges = []
        same_dest_group_edges = []

        for same_source_edge in same_source_edges:
            same_source_edge = [
                edge for edge in same_source_edge
                if isinstance(edge.to_node, ComparableParameters)
            ]
            while same_source_edge:
                first = same_source_edge.pop(0)
                others = [
                    edge for edge in same_source_edge
                    if first.to_node != edge.to_node
                    and edge.to_node.is_same_operation_as(G, first.to_node)
                ]
                if others:
                    same_dest_edges.append(tuple([first] + others))
                    for other in others:
                        same_source_edge.remove(other)
                    continue

                other_groups = [
                    edge for edge in same_source_edge
                    if first.to_node != edge.to_node
                    and edge.to_node.can_be_grouped_with(first.to_node)
                ]
                if other_groups:
                    same_dest_group_edges.append(
                        tuple([first] + other_groups))
                    for other in other_groups:
                        same_source_edge.remove(other)

        # all are multiple edges that go to something comparable
        while same_dest_edges:
            edge_set = same_dest_edges.pop(0)
            keep_node = edge_set[0].to_node
            # fold in every other set that also ends on the node we keep
            other_edge_sets = [
                edges for edges in same_dest_edges
                if any(edge.to_node == keep_node for edge in edges)
            ]
            for other_edge_set in other_edge_sets:
                same_dest_edges.remove(other_edge_set)

            nodes_to_delete = set()
            for edge_set in [edge_set] + other_edge_sets:
                for edge in edge_set:
                    other_node = edge.to_node
                    if other_node == keep_node or other_node in nodes_to_delete:
                        continue
                    nodes_to_delete.add(other_node)
                    for out_edge in G.out_edges(other_node):
                        # keep the producing output index so consumers of
                        # outputs other than 0 stay attached to the same
                        # output on the surviving node
                        G.add_edge(
                            NNEdge(from_node=keep_node,
                                   from_idx=out_edge.from_idx,
                                   to_node=out_edge.to_node,
                                   to_idx=out_edge.to_idx))
            LOG.info(
                f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
            )
            for node in nodes_to_delete:
                G.remove(node)
            if nodes_to_delete:
                # the graph changed: report it and scan again, since the
                # consumers of the merged nodes now share a source
                modified_graph = True
                found_more = True

        for edge_set in same_dest_group_edges:
            modified_graph = True
            found_more = True
            # we will merge all the convolutions into one
            first = edge_set[0]
            first_node = first.to_node
            in_edges = G.indexed_in_edges(first_node.name)
            first_filter = first_node.filter
            weights_node = in_edges[1].from_node
            biases_node = in_edges[2].from_node
            dup_nodes = []
            num_convs = len(edge_set)
            out_shape = deepcopy(first_node.out_dims[0])
            out_shape.c *= num_convs
            # create a split after the first node splitting on channel axis
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                out_shape,
                out_shape.get_order_idx('c'),
                num_splits=num_convs)
            split1 = SplitParameters(
                G.unique_name(f'{first_node.name}_split'),
                act_slices=act_slices,
                out_shapes=out_shapes,
                axis=axis)
            out_num = 0
            # first node out edge goes to split
            out_edges = G.out_edges(first_node.name)
            for edge in out_edges:
                G.remove_edge(edge)
                G.add_edge(
                    NNEdge(from_node=split1,
                           from_idx=out_num,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))
            G.add_edge(NNEdge(from_node=first_node, to_node=split1))
            # first split output goes to original output
            for other in edge_set[1::]:
                out_num += 1
                node_other = other.to_node
                dup_nodes.append(node_other.name)
                in_edges = G.indexed_in_edges(node_other.name)
                weights_other = in_edges[1].from_node
                biases_other = in_edges[2].from_node
                # merge the weights and biases down output channel
                weights_node.value = np.concatenate(
                    (weights_node.value, weights_other.value),
                    axis=first_filter.get_order_idx('out_c'))
                weights_node.dims = Dim.unnamed(weights_node.value.shape)
                biases_node.value = np.concatenate(
                    (biases_node.value, biases_other.value))
                biases_node.dims = Dim.unnamed(biases_node.value.shape)
                first_filter.out_c += node_other.filter.out_c
                # wire edge from split; the out edges are read before the
                # node they leave from is removed
                out_edges = G.out_edges(node_other.name)
                G.remove(node_other)
                G.remove(weights_other)
                G.remove(biases_other)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
            LOG.info(
                f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
            )

        if not found_more:
            break

    if set_identity:
        self.set_identity(G)
    return modified_graph
def match(self, G: GraphView, set_identity: bool = True) -> bool:
    """Fold ``PadParameters`` nodes into the filter node that follows them.

    For each out edge of a pad node ``self.find_conv`` is asked for a
    filter-like node and whether it is reached as a 1d convolution. When
    the padding only touches the axes hinted as ``h``/``w`` and all pad
    values are zero, the padding is stored on the filter, the pad node is
    bypassed on that edge, and for the 1d case the shapes of the node the
    edge pointed at are adjusted with ``self.remove_padding``. The pad node
    is removed only if every one of its out edges could be fused.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning.

    Returns:
        True if padding was fused into at least one filter.

    Raises:
        ValueError: A filter found behind a pad has no hint for input 0.
    """
    has_modified_graph = False
    pad_nodes = [
        candidate for candidate in G.nodes()
        if isinstance(candidate, PadParameters)
    ]
    for pad_params in pad_nodes:
        pad_in_edges = G.in_edges(pad_params.name)
        keep_pad = False
        for pad_out_edge in G.out_edges(pad_params.name):
            filter_like_node, is_1d = self.find_conv(G, pad_out_edge.to_node)
            if not filter_like_node:
                keep_pad = True
                continue
            if not filter_like_node.in_dims_hint or not filter_like_node.in_dims_hint[
                    0]:
                raise ValueError(
                    f"filter {filter_like_node.name} doesn't have a input hint"
                )
            in_hint = filter_like_node.in_dims_hint[0]

            padding_len = len(pad_params.padding)
            if is_1d and padding_len == 2:
                # slot an unpadded axis between the two padded ones
                expanded_padding = [
                    pad_params.padding[0], (0, 0), pad_params.padding[1]
                ]
            elif not is_1d and padding_len == 3:
                expanded_padding = pad_params.padding
            else:
                if is_1d:
                    LOG.warning(
                        "pad node %s is applied to 1d convolution but has length %s",
                        pad_params.name, padding_len)
                else:
                    LOG.warning(
                        "pad node %s is applied to 2d convolution but has length %s",
                        pad_params.name, padding_len)
                keep_pad = True
                continue

            # padding that is actually non-zero, keyed by hinted axis name
            hinted_pad = {
                in_hint[axis]: amount
                for axis, amount in enumerate(expanded_padding)
                if sum(amount) > 0
            }
            key_set = set(hinted_pad) - {'h', 'w'}
            if key_set:
                keep_pad = True
                LOG.error(
                    "node %s has padding on axes %s and cannot be fused with filter %s",
                    pad_params.name, key_set, filter_like_node.name)
                continue
            if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                keep_pad = True
                LOG.error(
                    "node %s has non zero pad values and cannot be fused with filter %s",
                    pad_params.name, filter_like_node.name)
                continue

            LOG.info("adding padding from: %s to %s filter: %s",
                     pad_params.name, "1D" if is_1d else "2D",
                     filter_like_node.name)
            for key in ('h', 'w'):
                hinted_pad.setdefault(key, (0, 0))
            filter_like_node.padding = PadDim(*hinted_pad['h'],
                                              *hinted_pad['w'])
            filter_like_node.pad_type = "zero"
            has_modified_graph = True

            G.remove_edge(pad_out_edge)
            if is_1d:
                reshape_node = pad_out_edge.to_node
                reshape_node.old_shape = self.remove_padding(
                    reshape_node.old_shape, pad_params.padding)
                reshape_node.shape = self.remove_padding(
                    reshape_node.shape, expanded_padding)
            # bypass the pad on this edge
            for in_edge in pad_in_edges:
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           to_node=pad_out_edge.to_node,
                           from_idx=in_edge.from_idx,
                           to_idx=pad_out_edge.to_idx))

        if not keep_pad:
            G.remove(pad_params)
            if G.quantization:
                G.quantization.remove_node(pad_params)

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Drop zero-padding nodes and fold other pads into the following filter.

    A ``PadParameters`` node with a single in edge and no padding on any
    axis is removed and its producer wired to its consumers. For any other
    pad node, ``self.find_conv`` is asked per out edge for a filter-like
    node, the padding expanded to the filter's input axes, and the reshapes
    in between. When the padding only touches the axes hinted as ``h``/``w``
    and all pad values are zero, the padding is stored on the filter, the
    reshapes are adjusted with ``self.remove_padding`` and the pad node is
    bypassed on that edge. The pad node is removed only if every one of its
    out edges could be fused.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning.
        **kwargs: Accepted and not used by this method.

    Returns:
        True if a zero-padding node was removed or padding was fused into
        at least one filter.

    Raises:
        ValueError: A filter found behind a pad has no hint for input 0.
    """
    has_modified_graph = False
    for pad_params in [pad for pad in G.nodes() if isinstance(pad, PadParameters)]:
        pad_in_edges = G.in_edges(pad_params.name)
        pad_out_edges = G.out_edges(pad_params.name)
        dont_delete = False
        if len(pad_in_edges) == 1 and all(sum(padding) == 0
                                          for padding in pad_params.padding):
            LOG.info("removing zero padding node %s", pad_params.name)
            G.remove(pad_params)
            if G.quantization:
                G.quantization.remove_node(pad_params)
            # removing the node is a graph change that must be reported
            has_modified_graph = True
            # already removed here - stop the common tail removing it again
            dont_delete = True
            in_edge = pad_in_edges[0]
            for out_edge in pad_out_edges:
                G.add_edge(NNEdge(from_node=in_edge.from_node,
                                  to_node=out_edge.to_node,
                                  from_idx=in_edge.from_idx,
                                  to_idx=out_edge.to_idx))
        else:
            for pad_out_edge in pad_out_edges:
                filter_like_node, expanded_padding, reshapes = self.find_conv(
                    G, pad_out_edge.to_node, pad_params.padding)
                if not filter_like_node:
                    dont_delete = True
                    continue
                if not filter_like_node.in_dims_hint or not filter_like_node.in_dims_hint[0]:
                    raise ValueError(
                        f"filter {filter_like_node.name} doesn't have a input hint")
                in_hint = filter_like_node.in_dims_hint[0]

                # padding that is actually non-zero, keyed by hinted axis
                hinted_pad = {in_hint[idx]: pad
                              for idx, pad in enumerate(expanded_padding)
                              if sum(pad) > 0}
                key_set = set(hinted_pad) - {'h', 'w'}
                if key_set:
                    dont_delete = True
                    LOG.error("node %s has padding on axes %s and cannot be fused with filter %s",
                              pad_params.name, key_set, filter_like_node.name)
                    continue
                if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                    dont_delete = True
                    LOG.error("node %s has non zero pad values and cannot be fused with filter %s",
                              pad_params.name, filter_like_node.name)
                    continue

                LOG.info("adding padding from: %s to filter: %s - has %s reshapes",
                         pad_params.name, filter_like_node.name, len(reshapes))

                for key in ['h', 'w']:
                    if key not in hinted_pad:
                        hinted_pad[key] = (0, 0)

                filter_like_node.padding = PadDim(
                    *(list(hinted_pad['h']) + list(hinted_pad['w'])))
                filter_like_node.pad_type = "zero"
                has_modified_graph = True
                G.remove_edge(pad_out_edge)
                for reshape_node, old_padding, new_padding in reshapes:
                    reshape_node.old_shape = self.remove_padding(
                        reshape_node.old_shape, old_padding)
                    reshape_node.shape = self.remove_padding(
                        reshape_node.shape, new_padding)

                # bypass the pad on this edge
                for in_edge in pad_in_edges:
                    G.add_edge(NNEdge(from_node=in_edge.from_node,
                                      to_node=pad_out_edge.to_node,
                                      from_idx=in_edge.from_idx,
                                      to_idx=pad_out_edge.to_idx))

        if not dont_delete:
            G.remove(pad_params)
            if G.quantization:
                G.quantization.remove_node(pad_params)

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Collapse chains of axis-0 concats into one concat of flat inputs.

    For every ``ConcatParameters`` node with ``axis == 0``,
    ``find_concats_up`` supplies a subgraph. When that subgraph holds more
    than one concat, all its nodes except the starting concat (and
    ``DummyInput`` nodes) are removed and the subgraph's inputs are wired
    straight into the remaining concat, each through a flattening reshape
    if it has more than one dimension. A final reshape after the concat
    restores the concat's original output shape.

    If every input's producer has a quantization record, the concat's
    record (when present) receives the collected input entries.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning.
        **kwargs: Accepted and not used by this method.

    Returns:
        True if at least one group of concats was combined.
    """
    modified_graph = False
    pending = set(G.nodes(node_classes=ConcatParameters))
    while pending:
        concat = pending.pop()
        if concat.axis != 0:
            continue
        subgraph = find_concats_up(G, concat)
        found = set(subgraph.nodes(node_classes=ConcatParameters))
        if len(found) <= 1:
            continue
        LOG.info(
            f"Combining concats {','.join([node.name for node in found])}")
        modified_graph = True
        # the concats absorbed here must not be visited on their own
        pending -= found

        in_edges = [inp.edge for inp in subgraph.inputs()]
        # captured before the graph is taken apart
        in_dims = [
            edge.from_node.out_dims[edge.from_idx] for edge in in_edges
        ]
        doomed = [
            node for node in subgraph.nodes()
            if node != concat and not isinstance(node, DummyInput)
        ]
        for edge in in_edges:
            G.remove_edge(edge)
        for node in doomed:
            if node.name in G:
                G.remove(node)
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

        out_dim = concat.out_dims[0]
        # in_qs becomes None as soon as one input has no usable record
        in_qs = []
        for idx, (edge, in_dim) in enumerate(zip(in_edges, in_dims)):
            src_node, src_idx = edge.from_node, edge.from_idx
            if len(in_dim) > 1:
                # flatten the input before it reaches the concat
                flatten = ReshapeParameters(
                    G.unique_name(f'{concat.name}_flat{idx}'),
                    old_shape=in_dim,
                    shape=Dim.unnamed([in_dim.size()]))
                G.add_edge(
                    NNEdge(from_node=src_node,
                           from_idx=src_idx,
                           to_node=flatten))
                src_node, src_idx = flatten, 0
            G.add_edge(
                NNEdge(from_node=src_node,
                       from_idx=src_idx,
                       to_node=concat,
                       to_idx=idx))
            if in_qs is not None and G.quantization:
                nid = NodeId(edge.from_node)
                if nid in G.quantization:
                    in_qs.append(G.quantization[nid].out_qs[edge.from_idx])
                else:
                    in_qs = None
            else:
                in_qs = None

        if in_qs is not None and G.quantization:
            nid = NodeId(concat)
            if nid in G.quantization:
                G.quantization[nid].in_qs = in_qs

        # give the flat result its original shape back
        expand = ReshapeParameters(
            G.unique_name(f'{concat.name}_expand'),
            old_shape=Dim.unnamed([out_dim.size()]),
            shape=out_dim)
        G.insert_node_after(concat, expand, edge_class=NNEdge)

    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Merge adjacent split outputs that feed adjacent inputs of one concat.

    For every ``SplitParameters`` node, each output with exactly one out
    edge is followed with ``search_down`` to a ``ConcatParameters`` node
    (passing ``CopyParameters`` and ``NoOPParameters``). Runs of consecutive
    outputs that arrive at consecutive inputs of the same concat form a
    group. For each group of two or more, the split's slices and shapes are
    reduced to a single output with ``reduce_slices``, the paths of all but
    the first member are removed, and the later split outputs and concat
    inputs are re-indexed to close the gap.

    Args:
        G: Graph to rewrite in place.
        set_identity: When true, ``self.set_identity(G)`` is called before
            returning, as the other matchers do.
        **kwargs: Accepted and not used by this method.

    Returns:
        True if at least one group was found.
    """
    edge_groups = []
    for node in G.nodes(node_classes=SplitParameters):
        cur_group = None
        for out_edge_bundle in G.indexed_out_edges(node):
            if len(out_edge_bundle) == 1:
                out_edge = out_edge_bundle[0]
                concat_node_edges = search_down(G,
                                                out_edge,
                                                ConcatParameters,
                                                can_pass=(CopyParameters,
                                                          NoOPParameters))
                if concat_node_edges:
                    if cur_group:
                        this_concat_edge = concat_node_edges[-1]
                        last_concat_edge = cur_group[-1][-1]
                        # same concat and the very next input: extend the run
                        if (this_concat_edge.to_node == last_concat_edge.to_node
                                and this_concat_edge.to_idx ==
                                last_concat_edge.to_idx + 1):
                            cur_group.append(concat_node_edges)
                            continue
                        if len(cur_group) > 1:
                            edge_groups.append(cur_group)
                    # start a new run with this output
                    cur_group = [concat_node_edges]
                    continue
            # this output breaks any run in progress
            if cur_group:
                if len(cur_group) > 1:
                    edge_groups.append(cur_group)
                cur_group = None
        if cur_group:
            if len(cur_group) > 1:
                edge_groups.append(cur_group)
            cur_group = None

    # we leave the splits and concats after this since they will be cleared
    # up by remove_noops
    # NOTE(review): the groups are computed before any rewriting; when one
    # split or concat takes part in more than one group, the indexes held by
    # later groups predate the re-indexing done for earlier ones - confirm
    # this case is handled (e.g. by the matcher being re-run).
    for edge_group in edge_groups:
        split_node = edge_group[0][0].from_node
        concat_node = edge_group[0][-1].to_node
        from_idx = edge_group[0][0].from_idx
        to_idx = edge_group[-1][0].from_idx
        LOG.info(
            f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
        )
        # combine slices and shapes on edges in group
        new_slice, new_shape = reduce_slices(
            split_node.act_slices[from_idx:to_idx + 1],
            split_node.out_shapes[from_idx:to_idx + 1])
        split_node.act_slices = split_node.act_slices[:from_idx] + [
            new_slice
        ] + split_node.act_slices[to_idx + 1:]
        split_node.out_shapes = split_node.out_shapes[:from_idx] + [
            new_shape
        ] + split_node.out_shapes[to_idx + 1:]
        # remove all edges and intermediate nodes on all edge groups except
        # the first
        for edge_list in edge_group[1:]:
            remove_edges(G, edge_list)
        out_edge_bundles = G.indexed_out_edges(split_node)
        # move edges beyond the edge group after the first index; an output
        # may have several consumers or none, so every edge of the bundle is
        # moved rather than asserting there is exactly one
        for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
            for edge in edge_list:
                G.remove_edge(edge)
                G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
        # reindex the in edges in the concat
        from_idx = edge_group[0][-1].to_idx
        to_idx = edge_group[-1][-1].to_idx
        in_edges = G.indexed_in_edges(concat_node)
        for offset, in_edge in enumerate(in_edges[to_idx + 1:]):
            G.remove_edge(in_edge)
            G.add_edge(NNEdge.clone(in_edge, to_idx=from_idx + 1 + offset))

    if set_identity:
        self.set_identity(G)
    return bool(edge_groups)