def _apply_transposes(slices, transposes):
    """Transpose, in place, each entry of ``slices`` that has a truthy transpose."""
    for position, transpose in enumerate(transposes):
        if transpose:
            slices[position] = slices[position].transpose(transpose)


def walk_down(subg: GraphView, node: Parameters, node_slices: Mapping[Parameters, Sequence[SlicedTensor]]):
    """Push slice descriptions from ``node`` along its out edges, recursively.

    ``node_slices`` maps a node to a list with one slot per input index; it is
    filled in (mutated) as the walk proceeds.

    * A ``FusionInputParameters`` node seeds the walk: its entry is replaced by
      a single ``SlicedTensor`` built from ``InputSlice.from_shape``.
    * Any other node is only processed once it has an entry with a non-``None``
      slot for every one of its in edges; otherwise the call returns and the
      node is left for a later visit.

    The node's slices are then transformed (input transposes, concat / split /
    strided slice, output transposes), written into the destination slot of
    every out edge, and each destination node is visited in turn.

    NOTE(review): transposes are applied in place, so for a node whose type is
    none of the three handled below, or on the input side of any node, the
    list stored in ``node_slices`` itself is modified. A node that is reached
    again after its inputs are complete is processed again - confirm that
    cannot double-apply ``transpose_in``.
    """
    if not isinstance(node, FusionInputParameters):
        slices = node_slices.get(node)
        # not every in edge has delivered a slice yet
        if (slices is None
                or len(slices) < subg.num_in_edges(node.name)
                or any(entry is None for entry in slices)):
            return
    else:
        source = InputSlice.from_shape(node, node.dims.shape)
        origin = tuple([0] * source.rank)
        slices = [SlicedTensor([SliceElement(origin, source.shape, source)])]
        node_slices[node] = slices

    can_transpose = isinstance(node, Transposable)

    if can_transpose and node.transpose_in:
        _apply_transposes(slices, node.transpose_in)

    # after this point ``slices`` describes the node's outputs
    if isinstance(node, ConcatParameters):
        slices = [SlicedTensor.concat(*slices, axis=node.axis)]
    elif isinstance(node, SplitParameters):
        slices = slices[0].split(node.act_slices)
    elif isinstance(node, StridedSliceParameters):
        slices = [slices[0].slice(node.act_slice)]

    if can_transpose and node.transpose_out:
        _apply_transposes(slices, node.transpose_out)

    out_edges = [
        edge
        for bundle in subg.indexed_out_edges(node.name)
        for edge in bundle
    ]

    # fill every destination slot first so that the recursion below sees the
    # complete set of outputs of this node
    for edge in out_edges:
        slots_needed = edge.to_idx + 1
        dest = node_slices.setdefault(edge.to_node, [None] * slots_needed)
        shortfall = slots_needed - len(dest)
        if shortfall > 0:
            dest = dest + [None] * shortfall
            node_slices[edge.to_node] = dest
        dest[edge.to_idx] = slices[edge.from_idx]

    for edge in out_edges:
        walk_down(subg, edge.to_node, node_slices)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove duplicated node chains that hang off the same output.

    Start points are nodes with exactly one output index that is consumed by
    more than one edge. For each one ``self.explore`` (defined elsewhere) is
    asked for a list of node sequences; presumably these are equivalent
    chains - verify against ``explore``. The first sequence is kept, every
    node of the others is removed from the graph (and from ``G.quantization``
    when present), and the consumers of the removed sequences are reconnected
    to the last node of the kept one.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted and ignored.

    Returns:
        True if ``explore`` returned a non-empty result for any start point.
    """
    def fans_out(candidate):
        # a single output index with several consumers
        return (len(G.indexed_out_edges(candidate.name)) == 1
                and len(G.out_edges(candidate.name)) > 1)

    worklist = [candidate for candidate in G.nodes() if fans_out(candidate)]
    changed = False

    while worklist:
        start = worklist.pop(0)
        strings = self.explore(G, [start])
        if not strings:
            continue
        changed = True

        primary = strings.pop(0)
        # nodes of the surviving sequence need no further examination
        for kept in primary:
            if kept in worklist:
                worklist.remove(kept)

        orphaned_edges = []
        for other in strings:
            # record the consumers before the sequence is removed
            orphaned_edges.extend(G.out_edges(other[-1].name))
            for duplicate in other:
                if duplicate in worklist:
                    worklist.remove(duplicate)
                G.remove(duplicate)
                nid = NodeId(duplicate)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
            LOG.info(
                f'removed duplicates from {primary[0].name} {",".join(node.name for node in other)}')

        # from_idx is not passed, so NNEdge's default applies
        survivor_end = primary[-1]
        for edge in orphaned_edges:
            G.add_edge(
                NNEdge(from_node=survivor_end,
                       to_node=edge.to_node,
                       to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return changed
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove a concat that is immediately undone by a split.

    A pair is removed when all of the following hold:

    * the split has exactly one in edge and it comes from a concat,
    * the concat has at most one out edge,
    * neither ``concat.transpose_out`` nor ``split.transpose_in`` is set,
    * both work on the same axis,
    * the split produces as many outputs as the concat has inputs and each
      output has the same size on that axis as the corresponding input.

    Both nodes are then removed and each concat input is wired straight to
    the consumers of the split output with the same index.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted and ignored.

    Returns:
        True if at least one pair was removed.
    """
    has_modified_graph = False
    # materialise the candidates first: the graph is edited inside the loop
    split_nodes = {
        node for node in G.nodes() if isinstance(node, SplitParameters)
    }
    for split_node in split_nodes:
        in_edges = G.in_edges(split_node.name)
        # exactly one producer is required; this also keeps in_edges[0] from
        # raising IndexError on a split that has no in edge at all
        if len(in_edges) != 1:
            continue
        concat_node = in_edges[0].from_node
        if not isinstance(concat_node, ConcatParameters):
            continue
        # the concat result must not be needed by anything else
        if len(G.out_edges(concat_node.name)) > 1:
            continue
        if concat_node.transpose_out or split_node.transpose_in:
            continue
        if concat_node.axis != split_node.axis:
            continue
        axis = concat_node.axis
        split_out_sizes = [
            out_shape[axis] for out_shape in split_node.out_shapes
        ]
        if len(split_out_sizes) != len(concat_node.in_dims):
            continue
        if not all(size == in_dim.shape[axis]
                   for size, in_dim in zip(split_out_sizes,
                                           concat_node.in_dims)):
            continue

        has_modified_graph = True
        LOG.info("removing unnecessary concat/split pair %s/%s",
                 concat_node.name, split_node.name)
        # capture the edges before the nodes are removed
        concat_in_edges = G.indexed_in_edges(concat_node.name)
        split_out_edges = G.indexed_out_edges(split_node.name)
        G.remove(split_node)
        G.remove(concat_node)
        # NOTE(review): assumes split_out_edges has an entry for every concat
        # input index - confirm for splits whose last outputs are unconnected
        for idx, concat_in_edge in enumerate(concat_in_edges):
            for out_edge in split_out_edges[idx]:
                G.add_edge(
                    NNEdge(from_node=concat_in_edge.from_node,
                           from_idx=concat_in_edge.from_idx,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Merge consecutive split outputs that feed consecutive inputs of one concat.

    Phase 1 scans the out edge bundles of every ``SplitParameters`` node in
    index order and collects "edge groups". Each group member is the result
    of ``search_down`` for one split output (used here as a sequence of edges
    whose first edge leaves the split and whose last edge enters the concat).
    A member joins the current group only if it ends on the same concat node
    at the input index directly after the previous member. Groups with fewer
    than two members are dropped.

    Phase 2 rewrites the graph for every group: the split's ``act_slices`` and
    ``out_shapes`` for the group are replaced by one combined entry from
    ``reduce_slices``, the paths of all members but the first are removed, and
    the remaining split out edges and concat in edges are re-indexed.

    NOTE(review): ``set_identity`` is accepted but never used in this body.
    NOTE(review): all groups are collected before any rewriting; if one split
    yields more than one group, later groups hold edges and indices captured
    before the earlier rewrite - confirm that case is handled elsewhere.

    Returns:
        True if at least one group was found.
    """
    edge_groups = []
    for node in G.nodes(node_classes=SplitParameters):
        cur_group = None
        for out_edge_bundle in G.indexed_out_edges(node):
            # only outputs with a single consumer are candidates
            if len(out_edge_bundle) == 1:
                out_edge = out_edge_bundle[0]
                concat_node_edges = search_down(G, out_edge, ConcatParameters, can_pass=(CopyParameters, NoOPParameters))
                if concat_node_edges:
                    if cur_group:
                        this_concat_edge = concat_node_edges[-1]
                        last_concat_edge = cur_group[-1][-1]
                        # same concat, next input index: extend the run
                        if this_concat_edge.to_node == last_concat_edge.to_node and this_concat_edge.to_idx == last_concat_edge.to_idx + 1:
                            cur_group.append(concat_node_edges)
                            continue
                        # run broken: keep it only if it has several members
                        if len(cur_group) > 1:
                            edge_groups.append(cur_group)
                    # start a new run with this output
                    cur_group = [concat_node_edges]
                    continue
            # this output cannot be part of a run: close any open run
            if cur_group:
                if len(cur_group) > 1:
                    edge_groups.append(cur_group)
                cur_group = None
        # close the run still open after the last output of this split
        if cur_group:
            if len(cur_group) > 1:
                edge_groups.append(cur_group)
            cur_group = None
    # we leave the splits and concats after this since they will be cleared up by remove_noops
    for edge_group in edge_groups:
        split_node = edge_group[0][0].from_node
        concat_node = edge_group[0][-1].to_node
        # first and last split output index covered by the group
        from_idx = edge_group[0][0].from_idx
        to_idx = edge_group[-1][0].from_idx
        LOG.info(
            f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
        )
        # combine slices and shapes on edges in group
        new_slice, new_shape = reduce_slices(
            split_node.act_slices[from_idx:to_idx + 1],
            split_node.out_shapes[from_idx:to_idx + 1])
        split_node.act_slices = split_node.act_slices[:from_idx] + [
            new_slice
        ] + split_node.act_slices[to_idx + 1:]
        split_node.out_shapes = split_node.out_shapes[:from_idx] + [
            new_shape
        ] + split_node.out_shapes[to_idx + 1:]
        # remove all edges and intermediate nodes on all edge groups except the first
        for edge_list in edge_group[1:]:
            remove_edges(G, edge_list)
        out_edge_bundles = G.indexed_out_edges(split_node)
        # move edges beyond the edge group after the first index
        for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
            # NOTE(review): only a single consumer per later output is supported
            assert len(edge_list) == 1
            edge = edge_list[0]
            G.remove_edge(edge)
            G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
        # reindex the in edges in the concat
        # from here on from_idx / to_idx are concat input indexes
        from_idx = edge_group[0][-1].to_idx
        to_idx = edge_group[-1][-1].to_idx
        in_edges = G.indexed_in_edges(concat_node)
        for offset, in_edge in enumerate(in_edges[to_idx + 1:]):
            G.remove_edge(in_edge)
            G.add_edge(NNEdge.clone(in_edge, to_idx=from_idx + 1 + offset))
    return bool(edge_groups)