def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace every matched chain of fusable nodes in ``G`` by one fusion node.

    Each node of ``G`` whose class is a key of ``VALID_FUSIONS`` is offered to
    ``self.get_node_list``. A match with at least two nodes in its ``order``
    is wrapped in a new node built by the match's ``fusions_class`` and the
    original nodes are removed from ``G``.

    Args:
        G: graph that is searched and rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called once all
            candidates have been processed.
        **kwargs: accepted but not used by this method.

    Returns:
        True if at least one chain was fused, otherwise False.
    """
    has_modified_graph = False

    for node in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
        node_list = self.get_node_list(
            G, node, FusionMatch(self._default_ktype))
        if node_list is None or len(node_list.order) < 2:
            continue

        chain = list(node_list.order)
        LOG.info("fusing nodes %s", ",".join([member.name for member in chain]))
        has_modified_graph = True

        # Internal graph of the fusion: the matched nodes linked one after
        # the other in match order.
        subgraph = GraphView()
        for upstream, downstream in zip(chain, chain[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        last_node = chain[-1]

        # The first node may have several inputs, but the chain is assumed to
        # expose exactly one output (index 0 of its last node).
        head = node_list.node
        input_mapping = [
            [(head, in_idx)] for in_idx in range(G.num_in_edges(head.name))
        ]
        output_mapping = [(last_node, 0)]
        pnode = node_list.fusions_class(
            head.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # The fusion's record takes its inputs from the first
                # contained record and its outputs from the last one.
                prec = QRec.copy_ktype(
                    qrecs[0],
                    in_qs=qrecs[0].in_qs,
                    out_qs=qrecs[-1].out_qs)
                for fnode in pnode.contained_nodes():
                    G.quantization.move_to_fusion(fnode, pnode)
                G.quantization[NodeId(pnode)] = prec

        # Capture the boundary edges before the nodes are removed, then
        # reattach them to the fusion node.
        in_edges = G.in_edges(head.name)
        out_edges = G.out_edges(last_node.name)
        for member in chain:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode,
                       from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node,
                       from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def walk_down(subg: GraphView, node: Parameters,
              node_slices: Mapping[Parameters, Sequence[SlicedTensor]]):
    """Propagate sliced-tensor descriptions from ``node`` through ``subg``.

    ``node_slices`` maps a node to a list with one entry per input index.
    A ``FusionInputParameters`` node seeds its own entry with a single slice
    covering its whole shape. Any other node is processed only once an entry
    exists for it, the entry is at least as long as the node's number of
    input edges, and no slot is still ``None``; otherwise the call returns
    and leaves the node to a later visit.

    Once the inputs are available, the node's effect (input transposes, then
    concat / split / strided slice, then output transposes) is applied, the
    results are written to the input slots of the destination nodes, and the
    walk recurses into each destination.

    Args:
        subg: graph whose edges are followed.
        node: node to process.
        node_slices: per-node slice lists, updated in place.
    """
    if isinstance(node, FusionInputParameters):
        source = InputSlice.from_shape(node, node.dims.shape)
        origin = tuple([0] * source.rank)
        slices = [SlicedTensor([SliceElement(origin, source.shape, source)])]
        node_slices[node] = slices
    else:
        slices = node_slices.get(node)
        # Not every incoming edge has delivered its slice yet.
        if (slices is None
                or len(slices) < subg.num_in_edges(node.name)
                or any(entry is None for entry in slices)):
            return

    # NOTE(review): this rewrites the list stored in node_slices in place, so
    # a second complete visit of the same node would transpose again -
    # confirm that cannot happen for the graphs this is used on.
    if isinstance(node, Transposable) and node.transpose_in:
        _transpose_slices(slices, node.transpose_in)

    if isinstance(node, ConcatParameters):
        slices = [SlicedTensor.concat(*slices, axis=node.axis)]
    elif isinstance(node, SplitParameters):
        slices = slices[0].split(node.act_slices)
    elif isinstance(node, StridedSliceParameters):
        slices = [slices[0].slice(node.act_slice)]

    if isinstance(node, Transposable) and node.transpose_out:
        _transpose_slices(slices, node.transpose_out)

    # First pass: hand every output to the matching input slot of each
    # destination, growing the destination list when it is too short.
    for fanout in subg.indexed_out_edges(node.name):
        for edge in fanout:
            needed = edge.to_idx + 1
            pending = node_slices.setdefault(edge.to_node, [None] * needed)
            shortfall = needed - len(pending)
            if shortfall > 0:
                pending = pending + [None] * shortfall
                node_slices[edge.to_node] = pending
            pending[edge.to_idx] = slices[edge.from_idx]

    # Second pass: continue the walk only after all outputs were published.
    for fanout in subg.indexed_out_edges(node.name):
        for edge in fanout:
            walk_down(subg, edge.to_node, node_slices)


def _transpose_slices(slices, transposes):
    """Apply each truthy entry of ``transposes`` to the same-index slice, in place."""
    for position, transpose in enumerate(transposes):
        if transpose:
            slices[position] = slices[position].transpose(transpose)
def _match(self, G: GraphView, node: Node, edge: Edge):
    """Tell whether ``node`` is a concat node with exactly one input edge.

    Args:
        G: graph used to count the input edges of ``node``.
        node: candidate node.
        edge: not used by this check.

    Returns:
        True when ``node`` is a ``ConcatParameters`` instance and ``G``
        reports exactly one input edge for it, otherwise False.
    """
    if not isinstance(node, ConcatParameters):
        return False
    return G.num_in_edges(node.name) == 1