def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse chains of nodes described by VALID_FUSIONS into fusion nodes.

    For every node whose class is a key of VALID_FUSIONS, ``get_node_list``
    is asked for a fusable chain. A chain of two or more nodes is copied into
    a new subgraph, wrapped in ``node_list.fusions_class`` and spliced into
    ``G`` in place of the original nodes. If the graph is quantized, any
    quantization stats are cleared, the contained records are moved into the
    fusion and a record (first node's in_qs, last node's out_qs) is stored
    for the fusion node.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if at least one fusion was created.
    """
    has_modified_graph = False
    # Snapshot the candidates: the loop removes nodes and adds fusion nodes.
    candidates = list(G.nodes(node_classes=tuple(VALID_FUSIONS.keys())))
    for node in candidates:
        # A candidate may already have been swallowed by an earlier fusion.
        if node.name not in G:
            continue
        node_list = self.get_node_list(G, node, FusionMatch(self._default_ktype))
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (fnode.name for fnode in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        last_node = None
        for snode in node_list.order:
            if last_node is not None:
                subgraph.add_edge(
                    NNEdge(from_node=last_node, to_node=snode))
            last_node = snode
        # assumption here is that the first node could have multiple inputs but definitely has only
        # one output
        input_mapping = [[(node_list.node, idx)]
                         for idx in range(G.num_in_edges(node_list.node.name))]
        output_mapping = [(last_node, 0)]
        pnode = node_list.fusions_class(
            node_list.node.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            # if there are quantization stats then clear them. They need to be created again
            G.quantization.stats = None
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = QRec.copy_ktype(
                    qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for fnode in pnode.contained_nodes():
                    G.quantization.move_to_fusion(fnode, pnode)
                G.quantization[NodeId(pnode)] = prec
        # capture the boundary edges before the chain is removed
        in_edges = G.in_edges(node_list.node.name)
        out_edges = G.out_edges(last_node.name)
        for snode in node_list.order:
            G.remove(snode)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Absorb ReverseParameters nodes around RNN nodes into the RNN's ``revert`` flag.

    A reverse node qualifies when it has exactly one out edge and that edge
    goes to an RNNBaseParameters node. For each qualifying reverse feeding RNN
    input 0, the RNN gets ``revert = True``, the reverse is removed and its
    producers are connected straight to the RNN. Reverse nodes hanging off the
    RNN's output 0 are then removed as well and bypassed.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any reverse node was absorbed.
    """
    # Only works for reverses connected to one RNN node. A list (in graph
    # order) keeps the processing order deterministic.
    reverse_nodes = [
        node for node in G.nodes()
        if (isinstance(node, ReverseParameters)
            and len(G.out_edges(node.name)) == 1
            and isinstance(G.out_edges(node.name)[0].to_node, RNNBaseParameters))
    ]
    has_modified_graph = False
    for reverse_node in reverse_nodes:
        # An output reverse of one RNN can also be the input reverse of the
        # next one; it may already have been removed by the loop further down.
        if reverse_node.name not in G:
            continue
        in_edges = G.in_edges(reverse_node.name)
        rnn_edge = G.out_edges(reverse_node.name)[0]
        if rnn_edge.to_idx != 0:
            LOG.warning("reverse on rnn input %s", rnn_edge.to_idx)
            continue
        assert not rnn_edge.to_node.revert, "RNN node is already reversed!"
        rnn_edge.to_node.revert = True
        LOG.info("fusing reverses into node %s", rnn_edge.to_node.name)
        has_modified_graph = True
        G.remove(reverse_node)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node, rnn_edge.to_node,
                       from_idx=edge.from_idx, to_idx=rnn_edge.to_idx))
        # snapshot: the loop removes nodes and adds edges on this RNN
        for edge in list(G.out_edges(rnn_edge.to_node.name)):
            if not isinstance(edge.to_node, ReverseParameters):
                continue
            if edge.from_idx != 0:
                LOG.warning("reverse on rnn output %s", edge.from_idx)
                continue
            rev_edges = G.out_edges(edge.to_node.name)
            G.remove(edge.to_node)
            for rev_edge in rev_edges:
                G.add_edge(
                    NNEdge(edge.from_node, rev_edge.to_node,
                           from_idx=edge.from_idx, to_idx=rev_edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove no-op PadParameters nodes and fuse others into filter-like nodes.

    For each PadParameters node:

    * if it has a single input and all padding is zero, it is removed and its
      input connected directly to its consumers;
    * otherwise, for each consumer, ``self.find_conv`` looks for a filter-like
      node. The padding is fused into that filter only if the filter has an
      input dims hint, the padding touches only the 'h'/'w' axes, and all pad
      values are zero. The pad node is deleted only when every consumer could
      be fused.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any pad node was removed or fused.

    Raises:
        ValueError: if a matched filter has no input dims hint.
    """
    has_modified_graph = False
    for pad_params in [pad for pad in G.nodes() if isinstance(pad, PadParameters)]:
        pad_in_edges = G.in_edges(pad_params.name)
        pad_out_edges = G.out_edges(pad_params.name)
        dont_delete = False
        if len(pad_in_edges) == 1 and all(sum(padding) == 0 for padding in pad_params.padding):
            LOG.info("removing zero padding node %s", pad_params.name)
            # this branch modifies the graph too, so it must be reported
            has_modified_graph = True
            G.remove(pad_params)
            if G.quantization:
                G.quantization.remove_node(pad_params)
            # already removed above; stop the generic removal at the end
            dont_delete = True
            in_edge = pad_in_edges[0]
            for out_edge in pad_out_edges:
                G.add_edge(NNEdge(from_node=in_edge.from_node,
                                  to_node=out_edge.to_node,
                                  from_idx=in_edge.from_idx,
                                  to_idx=out_edge.to_idx))
        else:
            for pad_out_edge in pad_out_edges:
                filter_like_node, expanded_padding, reshapes = self.find_conv(
                    G, pad_out_edge.to_node, pad_params.padding)
                if not filter_like_node:
                    dont_delete = True
                    continue
                if not filter_like_node.in_dims_hint or not filter_like_node.in_dims_hint[0]:
                    raise ValueError(
                        f"filter {filter_like_node.name} doesn't have a input hint")
                in_hint = filter_like_node.in_dims_hint[0]
                # padding keyed by the filter's axis names, non-zero entries only
                hinted_pad = {in_hint[idx]: pad for idx, pad in enumerate(expanded_padding)
                              if sum(pad) > 0}
                key_set = set(hinted_pad.keys())
                key_set -= set(['h', 'w'])
                if len(key_set) > 0:
                    dont_delete = True
                    LOG.error("node %s has padding on axes %s and cannot be fused with filter %s",
                              pad_params.name, key_set, filter_like_node.name)
                    continue
                if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                    dont_delete = True
                    LOG.error("node %s has non zero pad values and cannot be fused with filter %s",
                              pad_params.name, filter_like_node.name)
                    continue

                LOG.info("adding padding from: %s to filter: %s - has %s reshapes",
                         pad_params.name, filter_like_node.name, len(reshapes))

                for key in ['h', 'w']:
                    if key not in hinted_pad:
                        hinted_pad[key] = (0, 0)

                filter_like_node.padding = PadDim(
                    *(list(hinted_pad['h']) + list(hinted_pad['w'])))
                filter_like_node.pad_type = "zero"
                has_modified_graph = True
                G.remove_edge(pad_out_edge)
                # reshapes between pad and filter must drop the padded extent
                for reshape_node, old_padding, new_padding in reshapes:
                    reshape_node.old_shape = self.remove_padding(
                        reshape_node.old_shape, old_padding)
                    reshape_node.shape = self.remove_padding(
                        reshape_node.shape, new_padding)

                for in_edge in pad_in_edges:
                    G.add_edge(NNEdge(from_node=in_edge.from_node,
                                      to_node=pad_out_edge.to_node,
                                      from_idx=in_edge.from_idx,
                                      to_idx=pad_out_edge.to_idx))

        if not dont_delete:
            G.remove(pad_params)
            if G.quantization:
                G.quantization.remove_node(pad_params)

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Remove the unpack node that follows multi-cell RNN outputs.

    RNN nodes with ``n_output_cells > 1`` are passed to ``self.find_unpack``.
    The results are filtered by ``self.validate_slices`` and
    ``self.validate_multi_branch``, which give a mapping from unpack node to
    the RNN paths feeding it. Each path is processed by
    ``self.process_path``. The unpack node is then removed, and its single
    producer is connected to its consumers, through a ReshapeParameters node
    when a StridedSliceParameters unpack changes shape.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        False if nothing matched, otherwise True.
    """
    rnn_nodes = [
        self.find_unpack(G, node) for node in G.nodes()
        if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
    ]
    rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
    rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
    if not rnn_nodes_by_slice:
        return False

    for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
        modified_nodes = set()
        for rnn_unpack in rnn_unpacks:
            self.process_path(G, rnn_unpack, modified_nodes)
        # since process path will have removed all unnecessary nodes the edges will be correct here
        out_edges = G.out_edges(unpack_node.name)
        in_edges = G.in_edges(unpack_node.name)
        assert len(in_edges) == 1, "expecting unpack node to have only one in edge"
        in_edge = in_edges[0]
        changes_shape = unpack_node.changes_shape if isinstance(
            unpack_node, StridedSliceParameters) else False

        LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
        G.remove(unpack_node)

        # Here the strided slice can change the output shape of the RNN
        # so insert a reshape to do the shape change
        reshape = None
        if changes_shape:
            reshape = ReshapeParameters(
                unpack_node.name + '_reshape',
                old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                shape=Dim.unnamed(unpack_node.out_shape))
            G.add_edge(
                NNEdge(from_node=in_edge.from_node, to_node=reshape,
                       from_idx=in_edge.from_idx))
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=reshape, to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))
        else:
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node, to_node=out_edge.to_node,
                           from_idx=in_edge.from_idx, to_idx=out_edge.to_idx))

        # The reshape (if any) inherits the removed unpack node's record,
        # then the stale record is dropped.
        unpack_nid = NodeId(unpack_node)
        if G.quantization and unpack_nid in G.quantization:
            if reshape is not None:
                G.quantization[NodeId(reshape)] = G.quantization[unpack_nid]
            del G.quantization[unpack_nid]

    if set_identity:
        self.set_identity(G)

    return True
def match_function(self, G: GraphView):
    """Return the matches of a one-node fragment built from ``ConcatMatcher('0')``."""
    pattern = GraphView()
    pattern.add_node(ConcatMatcher('0'))
    matches = G.match_fragment(pattern)
    return matches
def match_function(self, G: GraphView):
    """Return matches of the fragment ``relu(upper_bound=6) -> matmul <- const(1/6)``.

    Node '0' is a ReluActivationParameters with ``upper_bound == 6`` feeding
    input 0 of node '1', a MatrixMulParameters. Node '2' is a
    ConstantInputParameters for which ``check_equals(G, node, 1.0 / 6.0)``
    holds, feeding input 1 of node '1'.
    """
    def is_relu6(node):
        return isinstance(node, ReluActivationParameters) and node.upper_bound == 6

    def is_matmul(node):
        return isinstance(node, MatrixMulParameters)

    def is_one_sixth(node):
        return (isinstance(node, ConstantInputParameters)
                and check_equals(G, node, 1.0 / 6.0))

    pattern = GraphView()
    for label, predicate in (('0', is_relu6), ('1', is_matmul), ('2', is_one_sixth)):
        pattern.add_node(MatchNode(label, matcher=predicate))
    pattern.add_edge(Edge('0', '1', to_idx=0))
    pattern.add_edge(Edge('2', '1', to_idx=1))
    return G.match_fragment(pattern)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Merge runs of split outputs that feed consecutive inputs of one concat.

    For every SplitParameters node, each single-edge output is followed down
    with ``search_down`` (passing through Copy/NoOP nodes) to a concat. Runs
    of adjacent split outputs that reach the same concat at consecutive input
    indices form a group. Each group of two or more is collapsed: the
    split's slices/shapes for the run are combined with ``reduce_slices``,
    the extra edges (and intermediate nodes, via ``remove_edges``) are
    removed, and the later split outputs / concat inputs are renumbered.

    Returns:
        True if at least one group was found and merged.
    """
    # NOTE(review): unlike the sibling passes in this file, set_identity is
    # not consulted here -- confirm this is intended.
    edge_groups = []
    for node in G.nodes(node_classes=SplitParameters):
        # cur_group: list of edge paths (split output -> ... -> concat input)
        cur_group = None
        for out_edge_bundle in G.indexed_out_edges(node):
            if len(out_edge_bundle) == 1:
                out_edge = out_edge_bundle[0]
                concat_node_edges = search_down(G, out_edge, ConcatParameters,
                                                can_pass=(CopyParameters, NoOPParameters))
                if concat_node_edges:
                    if cur_group:
                        this_concat_edge = concat_node_edges[-1]
                        last_concat_edge = cur_group[-1][-1]
                        # extend the run only if it lands on the next input of the same concat
                        if this_concat_edge.to_node == last_concat_edge.to_node and this_concat_edge.to_idx == last_concat_edge.to_idx + 1:
                            cur_group.append(concat_node_edges)
                            continue
                        if len(cur_group) > 1:
                            edge_groups.append(cur_group)
                    # start a new run with this path
                    cur_group = [concat_node_edges]
                    continue
            # any output that doesn't continue a run closes the current one
            if cur_group:
                if len(cur_group) > 1:
                    edge_groups.append(cur_group)
                cur_group = None
        # close a run that reaches the split's last output
        if cur_group:
            if len(cur_group) > 1:
                edge_groups.append(cur_group)
            cur_group = None

    # we leave the splits and concats after this since they will be cleared up by remove_noops
    # NOTE(review): the indices below are read from edges collected before any
    # rewrite; with two groups on the same split or concat the second group's
    # indices may be stale after the first is renumbered -- confirm.
    for edge_group in edge_groups:
        split_node = edge_group[0][0].from_node
        concat_node = edge_group[0][-1].to_node
        # range of split outputs covered by the group
        from_idx = edge_group[0][0].from_idx
        to_idx = edge_group[-1][0].from_idx
        LOG.info(
            f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
        )
        # combine slices and shapes on edges in group
        new_slice, new_shape = reduce_slices(
            split_node.act_slices[from_idx:to_idx + 1],
            split_node.out_shapes[from_idx:to_idx + 1])
        split_node.act_slices = split_node.act_slices[:from_idx] + [
            new_slice
        ] + split_node.act_slices[to_idx + 1:]
        split_node.out_shapes = split_node.out_shapes[:from_idx] + [
            new_shape
        ] + split_node.out_shapes[to_idx + 1:]
        # remove all edges and intermediate nodes on all edge groups except the first
        for edge_list in edge_group[1:]:
            remove_edges(G, edge_list)
        out_edge_bundles = G.indexed_out_edges(split_node)
        # move edges beyond the edge group after the first index
        for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
            assert len(edge_list) == 1
            edge = edge_list[0]
            G.remove_edge(edge)
            G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
        # reindex the in edges in the concat
        # (from_idx/to_idx now hold the concat input range of the group)
        from_idx = edge_group[0][-1].to_idx
        to_idx = edge_group[-1][-1].to_idx
        in_edges = G.indexed_in_edges(concat_node)
        for offset, in_edge in enumerate(in_edges[to_idx + 1:]):
            G.remove_edge(in_edge)
            G.add_edge(NNEdge.clone(in_edge, to_idx=from_idx + 1 + offset))
    return bool(edge_groups)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Replace groups of strided slices on one tensor by a single split.

    Slices are grouped by the (node, output index) they read from. A lone
    slice is handed to ``self.slice_to_split``. For larger groups, the slices
    must:

    * differ on exactly one axis, with stride 1;
    * tile that axis contiguously starting at 0.

    Such a group is replaced by a SplitParameters node, and the graph is
    requantized around it when quantized.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any group of slices was replaced by a split.
    """
    has_modified_graph = False
    slices_by_origin = {}
    for slice_node in [node for node in G.nodes() if isinstance(node, StridedSliceParameters)]:
        in_edge = G.in_edges(slice_node.name)[0]
        group = slices_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), [])
        group.append(slice_node)

    for slice_nodes in slices_by_origin.values():
        # per-axis tuples of (start, stop, step) across the group
        slices = list(zip(*[node.act_slice for node in slice_nodes]))
        if len(slice_nodes) == 1:
            self.slice_to_split(G, slice_nodes, slices)
            continue

        diff_slices = [(idx, elems) for idx, elems in enumerate(slices)
                       if not all(elems[0] == elem for elem in elems[1:])]
        if len(diff_slices) != 1:
            continue
        # strides must be one
        if any(sl[2] != 1 for sl in diff_slices[0][1]):
            continue
        # check if slices are consecutive and non overlapping: each slice's
        # stop must be the next slice's start
        slices = sorted(diff_slices[0][1], key=lambda x: x[0])
        if not all(sl[1] == slices[i + 1][0] for i, sl in enumerate(slices[:-1])):
            continue
        # get_splits only receives sizes, so the first piece must start at 0
        if slices[0][0] != 0:
            continue
        szes = [sl[1] - sl[0] for sl in slices]
        axis = diff_slices[0][0]
        slice_nodes = sorted(slice_nodes, key=lambda x: x.act_slice[axis][0])
        act_slices, out_shapes, axis = SplitParameters.get_splits(
            slice_nodes[0].in_dims[0].shape, axis, splits=szes)
        params = SplitParameters(slice_nodes[0].name + '_split',
                                 act_slices=act_slices,
                                 out_shapes=out_shapes,
                                 axis=axis)
        in_edge = G.in_edges(slice_nodes[0].name)[0]
        G.add_edge(
            NNEdge(from_node=in_edge.from_node, to_node=params, from_idx=in_edge.from_idx))
        sub_names = []
        for idx, node in enumerate(slice_nodes):
            sub_names.append(node.name)
            out_edges = G.out_edges(node.name)
            G.remove(node)
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=params, to_node=out_edge.to_node,
                           from_idx=idx, to_idx=out_edge.to_idx))
        if G.quantization:
            G.add_dimensions()
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            quantizer.quantize(G, start_nodes=[params])
            RemoveUnnecessaryQuantizeOperators().match(G)

        LOG.info(
            f'replaced slice nodes {",".join(sub_names)} with split node {params.name}'
        )

        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True) -> bool:
    """Replace a set of single-index gathers covering an axis by a split.

    Gathers are grouped by the (node, output index) they read. A group is
    converted when all of the following hold:

    * every gather uses the same axis;
    * each gather selects exactly one index (0-d or 1-d index array);
    * the selected indices are unique and their count equals the size of
      that axis.

    When the axis is not 0, transposes are set on the split to move that
    axis to the front.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any group of gathers was replaced.
    """
    def gather_indices(gather):
        # a 0-d index array is one index; otherwise iterate the elements
        if len(gather.indices.shape) == 0:
            return [int(gather.indices)]
        return [int(elem) for elem in gather.indices]

    has_modified_graph = False
    gathers_by_origin = {}
    for gather in [node for node in G.nodes() if isinstance(node, GatherParameters)]:
        in_edge = G.in_edges(gather.name)[0]
        group = gathers_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), [])
        group.append(gather)

    for in_edge, gathers in gathers_by_origin.items():
        axis = gathers[0].axis
        # every gather (the first one included) must gather on the same axis
        # with a single index -- the split below is built from one index each
        if not all(gather.axis == axis
                   and len(gather.indices.shape) <= 1
                   and len(gather_indices(gather)) == 1
                   for gather in gathers):
            continue
        # sort all the indices
        gathers = sorted(gathers, key=lambda x: gather_indices(x)[0])
        indices = [gather_indices(gather)[0] for gather in gathers]
        # All the indices must be unique and cover the whole axis (this could be
        # relaxed but then needs to handle gaps)
        in_shape = in_edge[0].out_dims[in_edge[1]].shape
        in_shape_without_axis = in_shape[:axis] + in_shape[axis + 1:]
        if len(set(indices)) != len(indices) or len(indices) != in_shape[axis]:
            continue

        # good for a split
        LOG.info("gathers from %s[%s] converted to a split", in_edge[0].name, in_edge[1])
        splits = []
        shapes = []
        out_edges = []
        for gather, index in zip(gathers, indices):
            splits.append([(index, index + 1, 1)])
            shapes.append(in_shape_without_axis)
            out_edges.append(G.out_edges(gather.name))
            G.remove(gather)
        params = SplitParameters("%s_split" % in_edge[0].name,
                                 act_slices=splits,
                                 out_shapes=shapes,
                                 axis=axis)
        if axis != 0:
            # permutation moving `axis` to the front, other axes keep their order
            trans = [axis] + list(range(axis)) + list(range(axis + 1, len(in_shape)))
            # inverse permutation for the output side
            params.transpose_out = [[trans.index(idx) for idx in range(len(trans))]]
            params.transpose_in = [trans]
        for idx, edges in enumerate(out_edges):
            for edge in edges:
                G.add_edge(
                    NNEdge(from_node=params, to_node=edge.to_node,
                           from_idx=idx, to_idx=edge.to_idx))
        G.add_edge(
            NNEdge(from_node=in_edge[0], to_node=params, from_idx=in_edge[1]))
        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Turn ``transpose_in`` / ``transpose_out`` attributes into explicit transpose nodes.

    For each Transposable node that is not itself a TransposeParameters, the
    following is done for every input (output) edge with a non-None entry in
    ``transpose_in`` (``transpose_out``):

    * a TransposeParameters node is inserted on that edge;
    * dims hints are moved onto it when present;
    * the node's quantization record is copied to it;
    * finally the node's transpose attribute is cleared.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any transpose node was inserted.
    """
    # get a list of all the nodes that are transposable but not transposes
    # Need to do this first to avoid mutating it when doing the modifications
    tnodes = list(
        filter(
            lambda n: isinstance(n, Transposable) and not isinstance(n, TransposeParameters),
            G.nodes()))
    has_modified_graph = False
    for node in tnodes:
        if node.transpose_in:
            # snapshot: insert_node rewires this node's in edges
            for idx, edge in enumerate(list(G.in_edges(node.name))):
                if edge.to_idx >= len(node.transpose_in):
                    continue
                trans = node.transpose_in[edge.to_idx]
                if trans is None:
                    continue
                LOG.info("Expand transpose in on node %s", node.name)
                has_modified_graph = True
                in_params = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                                transpose=trans,
                                                block_search_up=True)
                if node.in_dims_hint and node.in_dims_hint[edge.to_idx]:
                    in_hint = node.in_dims_hint[edge.to_idx]
                    out_hint = apply_transpose_to_hint(in_hint, trans)
                    in_params.in_dims_hint = [in_hint.copy()]
                    in_params.out_dims_hint = [out_hint.copy()]
                    node.in_dims_hint[edge.to_idx] = out_hint
                if G.quantization:
                    G.quantization.copy_qrec(node, 'in', edge.to_idx, in_params)
                G.insert_node(in_params, edge.from_node.name, edge.to_node.name,
                              from_idx=edge.from_idx, to_idx=edge.to_idx,
                              edge_class=NNEdge)
            node.transpose_in = None
        if node.transpose_out:
            # Hints as they were before expansion: an output with several
            # consumers must not have its hint reverse-transposed once per edge.
            orig_out_hints = list(node.out_dims_hint) if node.out_dims_hint else None
            for idx, edge in enumerate(list(G.out_edges(node.name))):
                if edge.from_idx >= len(node.transpose_out):
                    continue
                trans = node.transpose_out[edge.from_idx]
                if trans is None:
                    continue
                LOG.info("Expand transpose out on node %s", node.name)
                has_modified_graph = True
                out_params = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                                 transpose=trans,
                                                 block_search_down=True)
                # mirror the input side: skip outputs that have no hint
                if orig_out_hints and orig_out_hints[edge.from_idx]:
                    out_hint = orig_out_hints[edge.from_idx]
                    in_hint = apply_reverse_transpose_to_hint(out_hint, trans)
                    out_params.in_dims_hint = [in_hint.copy()]
                    out_params.out_dims_hint = [out_hint.copy()]
                    node.out_dims_hint[edge.from_idx] = in_hint
                if G.quantization:
                    G.quantization.copy_qrec(node, 'out', edge.from_idx, out_params)
                G.insert_node(out_params, edge.from_node.name, edge.to_node.name,
                              from_idx=edge.from_idx, to_idx=edge.to_idx,
                              edge_class=NNEdge)
            node.transpose_out = None
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match_function(self, G: GraphView):
    """Return matches of ``filter -> add(input 0)`` with a constant on ``add`` input 1.

    Node '0' is any FilterParameters, node '1' a MatrixAddParameters and node
    '2' a ConstantInputParameters.
    """
    def of_class(node_class):
        return lambda node: isinstance(node, node_class)

    pattern = GraphView()
    pattern.add_node(MatchNode('0', matcher=of_class(FilterParameters)))
    pattern.add_node(MatchNode('1', matcher=of_class(MatrixAddParameters)))
    pattern.add_node(MatchNode('2', matcher=of_class(ConstantInputParameters)))
    # the filter output and the constant are the two inputs of the add
    pattern.add_edge(Edge('0', '1', to_idx=0))
    pattern.add_edge(Edge('2', '1', to_idx=1))
    return G.match_fragment(pattern)
def match_function(self, G: GraphView):
    """Return matches of a Conv2D node optionally followed by activation and/or pooling.

    Node '0' is a Conv2DParameters accepted by ``self.valid_convolution``.
    Depending on ``self.match_activation`` / ``self.match_pool``, nodes '1'
    (and '2') are added via ``self.add_activation`` / ``self.add_pooling``.
    When both are requested, ``self.pool_after_activation`` selects their
    order.
    """
    pattern = GraphView()
    pattern.add_node(MatchNode(
        '0',
        matcher=lambda node: isinstance(node, Conv2DParameters) and self.valid_convolution(node)))
    if self.match_activation and self.match_pool:
        if self.pool_after_activation:
            add_first, add_second = self.add_activation, self.add_pooling
        else:
            add_first, add_second = self.add_pooling, self.add_activation
        add_first('1', pattern)
        add_second('2', pattern)
        pattern.add_edge(Edge('0', '1'))
        pattern.add_edge(Edge('1', '2'))
    elif self.match_activation or self.match_pool:
        add_only = self.add_activation if self.match_activation else self.add_pooling
        add_only('1', pattern)
        pattern.add_edge(Edge('0', '1'))
    return G.match_fragment(pattern)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Bypass and delete every node for which ``self.node_does_nothing`` is true.

    The node's first input edge is connected directly to each of its output
    edges' destinations, then the node is removed.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if at least one node was removed.
    """
    idle_nodes = [candidate for candidate in G.nodes()
                  if self.node_does_nothing(G, candidate)]
    for idle in idle_nodes:
        source_edge = G.in_edges(idle.name)[0]
        G.remove_edge(source_edge)
        for sink_edge in G.out_edges(idle.name):
            G.remove_edge(sink_edge)
            G.add_edge(NNEdge(source_edge.from_node,
                              sink_edge.to_node,
                              from_idx=source_edge.from_idx,
                              to_idx=sink_edge.to_idx))
        LOG.info(f'removing {idle.name} that does nothing')
        G.remove(idle)

    if set_identity:
        self.set_identity(G)

    return bool(idle_nodes)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold constant element-wise ops that follow a filter into the filter.

    Starting from each FilterParameters node, the single-consumer chain is
    followed, through ReshapeParameters nodes. Each consumer op whose class is
    a key of OPS, and whose other input is a ConstantInputParameters, is
    folded into the filter with ``self.fuse_bias``. The constant must be a
    scalar, or have ``out_c`` elements when the op is not marked
    ``weights_and_biases``. The op node is then removed, along with the
    constant if nothing else uses it.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any op was folded.
    """
    has_modified_graph = False
    filter_nodes = [node for node in G.nodes() if isinstance(node, FilterParameters)]
    for params in filter_nodes:
        filter_node = params
        seen_reshape = []
        while True:
            out_edges = G.out_edges(params.name)
            # can't fuse if there is a branch (or nothing follows)
            if len(out_edges) != 1:
                break
            out_edge = out_edges[0]
            op_node = out_edge.to_node
            if isinstance(op_node, ReshapeParameters):
                seen_reshape = [op_node]
                params = op_node
                continue
            # must be a valid matrix op
            if not isinstance(op_node, tuple(OPS.keys())):
                break
            # other edge to the op must be a constant
            other_idx = 1 if out_edge.to_idx == 0 else 0
            other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
            if not isinstance(other_in_edge.from_node, ConstantInputParameters):
                break
            const_node = other_in_edge.from_node
            # only delete the constant if this op is its sole consumer
            remove_constant = len(G.out_edges(const_node.name)) == 1

            flat_value = const_node.dqvalue.flatten()
            out_c = filter_node.filter.out_c
            op, weights_and_biases = OPS[op_node.__class__]
            # it would be possible to support mult bias addition by out channel but only supporting a
            # scalar at present
            if len(flat_value) != 1 and (weights_and_biases or len(flat_value) != out_c):
                LOG.warning('could not absorb %s into %s', const_node.name, filter_node.name)
                break
            # If there is quantization then essentially the output of the filter
            # takes the quantization of the output of the operation.
            # The biases will not change since their quantization depends on the weights
            # and input
            fnid = NodeId(filter_node)
            opnid = NodeId(op_node)
            if G.quantization and (fnid in G.quantization or opnid in G.quantization):
                if not (fnid in G.quantization and opnid in G.quantization):
                    LOG.warning(
                        'could not absorb %s into %s - graph is partially quantized',
                        const_node.name, filter_node.name)
                    break
                fqrec = G.quantization[fnid]
                opqrec = G.quantization[opnid]
                fqrec.out_qs[0] = opqrec.out_qs[0]

            has_modified_graph = True
            LOG.info("fusing bias in %s into %s", const_node.name, filter_node.name)
            self.fuse_bias(G, filter_node, other_idx, op, flat_value, 2)
            if weights_and_biases:
                # TODO - need to adjust weights quantization here
                LOG.info("fusing multiplicative bias in %s into %s",
                         const_node.name, filter_node.name)
                self.fuse_bias(G, filter_node, other_idx, op, flat_value, 1)

            out_edges = G.out_edges(op_node.name)
            G.remove(op_node)
            if remove_constant:
                G.remove(const_node)
            # reconnect after the last reshape seen, if any, else after the filter
            from_node = seen_reshape[-1] if seen_reshape else filter_node
            for edge in out_edges:
                G.add_edge(NNEdge(from_node=from_node, to_node=edge.to_node,
                                  to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Merge trees of axis-0 concats into the last concat.

    For each axis-0 ConcatParameters node, ``find_concats_up`` returns a
    subgraph; when it holds more than one concat, several things happen:

    * all subgraph nodes other than this concat (and DummyInput
      placeholders) are removed;
    * the subgraph's inputs are wired straight into this concat, each
      multi-dimensional input flattened with a ReshapeParameters;
    * a reshape back to the concat's original output dims is inserted after
      it.

    If every input producer has a quantization record, the concat's in_qs are
    rebuilt from their out_qs.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any concats were combined.
    """
    modified_graph = False
    concats = set(G.nodes(node_classes=ConcatParameters))
    while concats:
        concat = concats.pop()
        if concat.axis != 0:
            continue
        subgraph = find_concats_up(G, concat)
        found = set(subgraph.nodes(node_classes=ConcatParameters))
        if len(found) <= 1:
            continue
        LOG.info(
            f"Combining concats {','.join([node.name for node in found])}")
        modified_graph = True
        # concats merged here must not be processed again
        concats -= found

        in_edges = [inp.edge for inp in subgraph.inputs()]
        # record the input dims before the producing edges are removed
        in_dims = [
            edge.from_node.out_dims[edge.from_idx] for edge in in_edges
        ]
        nodes_to_remove = [
            node for node in subgraph.nodes() if node != concat and not isinstance(node, DummyInput)
        ]
        for edge in in_edges:
            G.remove_edge(edge)
        for node in nodes_to_remove:
            if node.name in G:
                G.remove(node)
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]
        # remove_internal_graph(G, subgraph)
        out_dim = concat.out_dims[0]
        # in_qs becomes None as soon as one producer lacks a quantization record
        in_qs = []
        for idx, edge in enumerate(in_edges):
            from_node = edge.from_node
            from_idx = edge.from_idx
            if len(in_dims[idx]) > 1:
                # flatten multi-dimensional inputs so they can be joined on axis 0
                # NOTE(review): this reshape gets no quantization record here -- confirm
                reshape = ReshapeParameters(
                    G.unique_name(f'{concat.name}_flat{idx}'),
                    old_shape=in_dims[idx],
                    shape=Dim.unnamed([in_dims[idx].size()]))
                G.add_edge(
                    NNEdge(from_node=from_node, from_idx=from_idx, to_node=reshape))
                from_node = reshape
                from_idx = 0
            G.add_edge(
                NNEdge(from_node=from_node, from_idx=from_idx, to_node=concat, to_idx=idx))
            if in_qs is not None and G.quantization:
                nid = NodeId(edge.from_node)
                if nid in G.quantization:
                    qrec = G.quantization[nid]
                    in_qs.append(qrec.out_qs[edge.from_idx])
                else:
                    in_qs = None
            else:
                in_qs = None
        if in_qs is not None and G.quantization:
            nid = NodeId(concat)
            if nid in G.quantization:
                G.quantization[nid].in_qs = in_qs
        # the concat now produces a flat tensor; restore its original output dims
        reshape = ReshapeParameters(G.unique_name(f'{concat.name}_expand'),
                                    old_shape=Dim.unnamed([out_dim.size()]),
                                    shape=out_dim)
        G.insert_node_after(concat, reshape, edge_class=NNEdge)

    if set_identity:
        self.set_identity(G)

    return modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fill in missing quantization records from neighbouring nodes.

    Every record in ``G.quantization`` that is None, or lacks in_qs or out_qs,
    is rebuilt as a MultQuantizationRecord. The record takes output types from
    its successors' in_qs and input types from its predecessors' out_qs. A side
    is used only when all nodes on it are quantized with
    MultQuantizationRecord. Input/output nodes copy the one available side to
    both.

    Returns nothing (also when the graph is not quantized).

    Raises:
        NotImplementedError: when only one side is available for a node that
            has neighbours on the other side as well.
    """
    if not G.quantization:
        return
    for nid in [
            nid for nid, qrec in G.quantization.sorted_iterator(G)
            if qrec is None or not (qrec.in_qs and qrec.out_qs)
    ]:
        if nid.fnode_name:
            LOG.warning("can't add quantization to fused node %s", nid.fnode_name)
            continue
        if nid.node_name not in G:
            # previous fusions may have removed nodes from the graph
            continue
        node = nid.get_node(G)
        predecessors = [NodeId(pred) for pred in G.predecessors(node.name)]
        successors = [
            NodeId(succ) for succs in G.successors(node.name) for succ in succs
        ]
        # a side is usable if the other side is empty or all its nodes are quantized
        go_back = not successors or (predecessors and all(pred in G.quantization for pred in predecessors))
        go_forward = not predecessors or (successors and all(succ in G.quantization for succ in successors))
        if not (go_back or go_forward):
            # NOTE(review): reached when the node has both predecessors and
            # successors but neither side is fully quantized; the message
            # wording does not match that condition.
            LOG.warning(
                "node %s is not connected to anything and has no quantization", node.name)
            continue

        if go_forward:
            out_qrecs = set(G.quantization[nid] for nid in successors)
            if not all(
                    isinstance(out_qrec, MultQuantizationRecord) for out_qrec in out_qrecs):
                continue
            # output types as seen by each consumer's input
            out_qtypes = reduce_qtypes([
                (edge.from_idx, G.quantization[NodeId(edge.to_node)].in_qs[edge.to_idx])
                for edge in G.out_edges(node.name)
            ])
        else:
            out_qtypes = None
        if go_back:
            in_qrecs = set(G.quantization[nid] for nid in predecessors)
            if not all(
                    isinstance(in_qrec, MultQuantizationRecord) for in_qrec in in_qrecs):
                continue
            # input types as produced by each producer's output
            in_qtypes = reduce_qtypes([(edge.to_idx, G.quantization[NodeId(
                edge.from_node)].out_qs[edge.from_idx]) for edge in G.in_edges(node.name)])
        else:
            in_qtypes = None

        if not in_qtypes:
            if not predecessors:
                LOG.info("setting quantization on input node %s", node.name)
                qrec = MultQuantizationRecord(in_qs=deepcopy(out_qtypes), out_qs=deepcopy(out_qtypes))
            else:
                raise NotImplementedError("propagating qrecs not implemented")
        elif not out_qtypes:
            if not successors:
                LOG.info("setting quantization on output node %s", node.name)
                qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes), out_qs=deepcopy(in_qtypes))
            else:
                raise NotImplementedError("propagating qrecs not implemented")
        else:
            LOG.info("setting quantization on node %s", node.name)
            qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes), out_qs=deepcopy(out_qtypes))
        G.quantization[nid] = qrec

    if set_identity:
        self.set_identity(G)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each connected group from ``find_connected_groups`` into an expression node.

    Each fragment is replaced in ``G`` by an ExpressionFusionParameters node
    named ``expr_<n>``. When the graph is quantized:

    * a 'scaled' QRec is built for the expression from the boundary in_qs /
      out_qs, with their min/max values also written into the symbol stats;
    * QuantizeParameters consumers of the expression are removed so the
      quantizer can absorb them;
    * the graph is re-quantized with NewQuantizer.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if at least one fragment was fused.
    """
    has_modified_graph = False
    to_quantize = []
    for frag in find_connected_groups(G):
        # fresh stats collector for the symbols of this expression
        Symbol.set_default_control(SymbolStats())
        has_modified_graph = True
        # iterating either mapping yields (from_node, from_idx) keys; its values are sets of edges
        in_edges, out_edges = external_edges(G, frag)
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        LOG.debug(
            "inputs coming from: %s",
            ",".join(f"{from_node.__repr__()}:{from_idx}" for from_node, from_idx in in_edges))
        LOG.info("fusing nodes: %s into expr_%s",
                 ",".join(node.__repr__() for node in frag.nodes()), self._expr_num)
        expr = ExpressionFusionParameters(
            G.unique_name(f"expr_{self._expr_num}"),
            subgraph=frag,
            qrecs=G.quantization,
            input_mapping=in_mapping,
            output_mapping=out_mapping,
            in_dims=in_dims,
            out_dims=out_dims,
            constant_inputs=constant_inputs)

        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping,
            edge_class=NNEdge)
        if G.quantization:
            qrecs = G.quantization
            # input qtypes come from the first fragment node consuming each input
            in_qs = [
                qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                for in_map in in_mapping
            ]
            out_qs = [
                qrecs[NodeId(node)].out_qs[idx]
                for node, idx in out_mapping
            ]
            stats = Symbol.CURRENT_CONTROL.stats
            func_col = expr.func_col
            # seed the expression's input/output symbol ranges from the qtypes
            for idx, qtype in enumerate(in_qs):
                symbol = func_col.variables[func_col.input_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            for idx, qtype in enumerate(out_qs):
                symbol = func_col.variables[func_col.output_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            nid = NodeId(expr)
            G.quantization[nid] = QRec(in_qs=in_qs, out_qs=out_qs, expression=stats, ktype='scaled')
            if G.quantization.stats:
                G.quantization.stats[nid] = {
                    'range_in': [{
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    } for qtype in in_qs],
                    'range_out': [{
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    } for qtype in out_qs],
                    'expression': stats.copy()
                }
            # delete any quantize parameters on outputs to allow the quantizer
            # to fuse them into the expression
            out_edges = G.out_edges(expr.name)
            for edge in out_edges:
                if isinstance(edge.to_node, QuantizeParameters):
                    G.remove_and_reconnect(edge.to_node, edge_class=NNEdge)
                    if NodeId(edge.to_node) in G.quantization:
                        del G.quantization[NodeId(edge.to_node)]
            to_quantize.append(expr)

        self._expr_num += 1

    if to_quantize:
        quantizer = NewQuantizer.from_quantized_graph(G)
        quantizer.quantize()

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, node: Node, edge: Edge):
    """Return True when *node* is a ConcatParameters with exactly one incoming edge."""
    if not isinstance(node, ConcatParameters):
        return False
    return G.num_in_edges(node.name) == 1
def match(self, G: GraphView, set_identity: bool = True):
    """Remove unpack nodes that only keep the last cell of an RNN's output.

    ``self.find_unpack`` is called for every RNNBaseParameters node. For a
    non-empty result, ``rnn_unpack[0]`` is the RNN, ``rnn_unpack[-1]`` is the
    unpack node, and nodes in between are removed as well. The unpack must be
    a StridedSliceParameters, or a SplitParameters with its last output as
    the only used edge, whose slice on axis 0 (``time_axis``) selects only
    index ``n_cells - 1``. The RNN is then set to ``n_output_cells = 1`` and
    wired to the unpack's consumers, through a ReshapeParameters node when
    the strided slice changes shape.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any unpack node was removed.
    """
    rnn_nodes = [
        self.find_unpack(G, node) for node in G.nodes()
        if isinstance(node, RNNBaseParameters)
    ]
    has_modified_graph = False
    for rnn_unpack in rnn_nodes:
        if not rnn_unpack:
            continue
        unpack_node = rnn_unpack[-1]
        rnn_node = rnn_unpack[0]
        time_axis = 0
        if not isinstance(unpack_node, (StridedSliceParameters, SplitParameters)):
            continue
        out_edges = G.out_edges(unpack_node.name)
        if not out_edges:
            LOG.debug("can't remove %s. No output edges", unpack_node.name)
            continue
        if isinstance(unpack_node, StridedSliceParameters):
            if unpack_node.act_slice[time_axis][1] != rnn_node.n_cells:
                LOG.debug("can't remove %s. Slice not equal to cells", unpack_node.name)
                continue
            if unpack_node.act_slice[time_axis][2] != 1:
                LOG.debug("can't remove %s. Slice not of length 1", unpack_node.name)
                continue
            if unpack_node.act_slice[time_axis][0] != rnn_node.n_cells - 1:
                LOG.debug("can't remove %s. Slice isn't last cell", unpack_node.name)
                continue
            changes_shape = unpack_node.changes_shape
        else:
            if len(out_edges) > 1:
                LOG.debug("can't remove %s. More than one output edge", unpack_node.name)
                continue
            if out_edges[0].from_idx != len(unpack_node.act_slices) - 1:
                LOG.debug("can't remove %s. Not last output", unpack_node.name)
                continue
            act_slice = unpack_node.act_slices[-1]
            if act_slice[time_axis][1] != rnn_node.n_cells:
                LOG.debug("can't remove %s. Slice not equal to cells", unpack_node.name)
                continue
            if act_slice[time_axis][0] != rnn_node.n_cells - 1:
                LOG.debug("can't remove %s. Slice isn't last cell", unpack_node.name)
                continue
            changes_shape = False

        has_modified_graph = True
        LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
        for node in rnn_unpack[1:-1]:
            LOG.info("Eliminating others: %s", node.name)
            if G.quantization:
                del G.quantization[NodeId(node)]
            G.remove(node)
        G.remove(unpack_node)
        # The split branch has exactly one edge; the edges of a strided slice
        # are assumed to share its single output index.
        from_idx = out_edges[0].from_idx
        rnn_node.n_output_cells = 1
        rnn_node.out_dims[0] = unpack_node.out_dims[from_idx]
        if unpack_node.out_dims_hint and unpack_node.out_dims_hint[from_idx]:
            rnn_node.out_dims_hint = [unpack_node.out_dims_hint[from_idx]]
        else:
            rnn_node.out_dims_hint = None

        # Here the strided slice can change the output shape of the RNN
        # so insert a reshape to do the shape change
        reshape = None
        if changes_shape:
            reshape = ReshapeParameters(
                unpack_node.name + '_reshape',
                old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                shape=Dim.unnamed(unpack_node.out_shape))
            G.add_edge(NNEdge(rnn_node, reshape))
            source = reshape
        else:
            source = rnn_node
        # every consumer of the unpack node is reconnected, not only the first
        for out_edge in out_edges:
            G.add_edge(NNEdge(source, out_edge.to_node, to_idx=out_edge.to_idx))

        # The reshape (if any) inherits the removed unpack node's record,
        # then the stale record is dropped.
        unpack_nid = NodeId(unpack_node)
        if G.quantization and unpack_nid in G.quantization:
            if reshape is not None:
                G.quantization[NodeId(reshape)] = G.quantization[unpack_nid]
            del G.quantization[unpack_nid]

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True) -> bool:
    """Remove a concat immediately undone by a split on the same axis.

    A pair matches when all of the following hold:

    * the split's only input comes from a concat whose only consumer is that
      split;
    * neither node carries the relevant transpose;
    * both use the same axis;
    * the split's output sizes on that axis equal the concat's input sizes,
      one for one.

    Each concat producer is then connected directly to the consumers of the
    matching split output.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` before returning.

    Returns:
        True if any concat/split pair was removed.
    """
    has_modified_graph = False
    split_nodes = {candidate for candidate in G.nodes()
                   if isinstance(candidate, SplitParameters)}
    for split_node in split_nodes:
        split_in_edges = G.in_edges(split_node.name)
        if len(split_in_edges) > 1:
            continue
        concat_node = split_in_edges[0].from_node
        if not isinstance(concat_node, ConcatParameters):
            continue
        if len(G.out_edges(concat_node.name)) > 1:
            continue
        if concat_node.transpose_out or split_node.transpose_in:
            continue
        if concat_node.axis != split_node.axis:
            continue
        axis = concat_node.axis
        piece_sizes = [shape[axis] for shape in split_node.out_shapes]
        if len(piece_sizes) != len(concat_node.in_dims):
            continue
        if any(piece_sizes[pos] != in_dim.shape[axis]
               for pos, in_dim in enumerate(concat_node.in_dims)):
            continue

        has_modified_graph = True
        LOG.info("removing unnecessary concat/split pair %s/%s",
                 concat_node.name, split_node.name)
        producers = G.indexed_in_edges(concat_node.name)
        consumers_by_output = G.indexed_out_edges(split_node.name)
        G.remove(split_node)
        G.remove(concat_node)
        for pos, producer in enumerate(producers):
            for consumer in consumers_by_output[pos]:
                G.add_edge(NNEdge(from_node=producer.from_node,
                                  from_idx=producer.from_idx,
                                  to_node=consumer.to_node,
                                  to_idx=consumer.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def reverse_matmul(G: GraphView, params):
    """Swap the two matrix operands of ``params`` and wrap it in transposes.

    The edges feeding inputs 0 and 1 of ``params`` are reconnected crosswise
    (input 0 -> 1, input 1 -> 0). If the graph holds a quantization record
    for ``params``, its first two ``in_qs`` entries are swapped to match. A
    TransposeParameters node with ``transpose=(1, 0)`` is then inserted
    before each of the two inputs and one after the output. Assuming 2D
    operands, this computes (B^T @ A^T)^T, which equals A @ B.

    Args:
        G: graph modified in place.
        params: the matmul node to reverse.

    Returns:
        A tuple ``(in_transposes, out_transpose)``: the list of the two
        transposes inserted before inputs 0 and 1, and the transpose inserted
        after the output.
    """
    operand_edges = G.indexed_in_edges(params.name)[0:2]
    # Detach both operands first so the crosswise re-attachment never
    # meets a still-connected edge.
    for operand_edge in operand_edges:
        G.remove_edge(operand_edge)
    for operand_edge, swapped_idx in zip(operand_edges, (1, 0)):
        G.add_edge(
            NNEdge(from_node=operand_edge.from_node,
                   to_node=params,
                   from_idx=operand_edge.from_idx,
                   to_idx=swapped_idx))

    qrec_id = NodeId(params)
    if G.quantization and qrec_id in G.quantization:
        swapped_qrec = G.quantization[qrec_id]
        # keep the input quantization aligned with the swapped operands
        swapped_qrec.in_qs[0], swapped_qrec.in_qs[1] = (
            swapped_qrec.in_qs[1], swapped_qrec.in_qs[0])

    in_transposes = []
    for operand_idx in (0, 1):
        in_transpose = TransposeParameters(
            G.unique_name(f"{params.name}_tin{operand_idx}"), transpose=(1, 0))
        in_transposes.append(in_transpose)
        G.insert_node_before(in_transpose, params,
                             to_idx=operand_idx, edge_class=NNEdge)

    out_transpose = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0))
    G.insert_node_after(params, out_transpose)
    return in_transposes, out_transpose
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse each Conv2D with the activation / pooling nodes that follow it.

    For every Conv2DParameters node, ``self.get_node_list`` proposes the
    chain to fuse; chains shorter than two nodes are ignored. In a
    'conv_active_pool' chain an average pool is dropped from the fusion (only
    the first two nodes are fused). A 'conv_pool_active' chain with an
    average pool is skipped unless its activation is "relu".

    The chain is replaced in ``G`` by a ConvFusionParameters node that takes
    the conv's three inputs and the last chain node's output. When the graph
    is quantized, the contained records are moved into the fusion and a
    Symmetric / Mult / Float32 record is created for the fusion from the
    first record's in_qs and the last record's out_qs. The fusion record is
    None if the first record is of none of those families.

    Args:
        G: graph modified in place.
        set_identity: call ``self.set_identity(G)`` before returning.

    Returns:
        True if at least one fusion was made.
    """
    has_modified_graph = False
    conv_nodes = [candidate for candidate in G.nodes()
                  if isinstance(candidate, Conv2DParameters)]
    for conv_node in conv_nodes:
        node_list = self.get_node_list(G, conv_node)
        if node_list is None or len(node_list.order) < 2:
            continue

        fusion_kind = node_list.fusion_type
        if fusion_kind == 'conv_active_pool' and node_list.pool.pool_type == "average":
            # keep the average pool outside the fusion
            node_list.order = node_list.order[:2]
            node_list.pool = None
        elif (fusion_kind == 'conv_pool_active'
              and node_list.pool.pool_type == "average"
              and node_list.active.activation != "relu"):
            continue

        chain = node_list.order
        LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
        has_modified_graph = True

        subgraph = GraphView()
        for upstream, downstream in zip(chain, chain[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        tail = chain[-1]

        conv = node_list.conv
        pnode = ConvFusionParameters(
            conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=conv.in_dims_hint,
            out_dims_hint=conv.out_dims_hint,
            input_mapping=[[(conv, in_idx)] for in_idx in range(3)],
            output_mapping=[(tail, 0)])

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                head_qrec, tail_qrec = qrecs[0], qrecs[-1]
                if isinstance(head_qrec, (SymmetricQuantizationRecord,
                                          SymmetricScalableFilterQuantizationRecord)):
                    fused_record_cls = SymmetricQuantizationRecord
                elif isinstance(head_qrec, (MultQuantizationRecord,
                                            MultScalableFilterQuantizationRecord)):
                    fused_record_cls = MultQuantizationRecord
                elif isinstance(head_qrec, (Float32QuantizationRecord,
                                            Float32ScalableFilterQuantizationRecord)):
                    fused_record_cls = Float32QuantizationRecord
                else:
                    fused_record_cls = None
                fused_qrec = None if fused_record_cls is None else fused_record_cls(
                    in_qs=head_qrec.in_qs, out_qs=tail_qrec.out_qs)
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        # capture the boundary edges before removing the fused nodes
        in_edges = G.in_edges(conv.name)
        out_edges = G.out_edges(tail.name)
        for member in chain:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold constant MatrixAdd / MatrixMul nodes that follow a MatMul into it.

    Starting from each MatMulOpParameters node, the single-consumer output
    chain is followed through any ReshapeParameters nodes. Each
    MatrixAddParameters / MatrixMulParameters whose other operand is a
    ConstantInputParameters is folded into the matmul:

    * additive constants are added to the existing constant bias (input 2)
      or become a new bias. If there is no bias and the constant's length
      equals the second output axis, the matmul is first reversed with
      ``reverse_matmul`` so that the bias lines up with the first axis.
    * multiplicative constants are multiplied into a constant matmul operand
      (input 1 preferred, else input 0) and into the bias, if present.

    Constants shared with other consumers are never modified or deleted.
    If anything was folded, dimensions are recalculated, a quantized graph
    is requantized, and the node order is adjusted when transposes were
    inserted.

    Args:
        G: graph modified in place.
        set_identity: call ``self.set_identity(G)`` before returning.

    Returns:
        True if at least one operation was folded.

    Raises:
        ValueError: if a matmul followed by a matrix op with a constant
            operand does not have a 2D output.
    """
    has_modified_graph = False
    has_transposed = False
    for params in G.nodes(node_classes=MatMulOpParameters):
        matmul = params
        seen_reshape = []
        while True:
            out_edges = G.out_edges(params.name)
            # can't fuse if there is a branch (or nothing consumes the output)
            if len(out_edges) != 1:
                break
            out_edge = out_edges[0]
            op_node = out_edge.to_node
            if isinstance(op_node, ReshapeParameters):
                seen_reshape.append(op_node)
                params = op_node
                continue
            # must be a valid matrix op
            if not isinstance(op_node, (MatrixAddParameters, MatrixMulParameters)):
                break
            # other edge to the op must be a constant
            other_idx = 1 if out_edge.to_idx == 0 else 0
            other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
            if not isinstance(other_in_edge.from_node, ConstantInputParameters):
                break
            const_node = other_in_edge.from_node
            # only delete the constant if the folded op is its sole consumer
            remove_constant = len(G.out_edges(const_node.name)) == 1
            flat_value = const_node.dqvalue.flatten()
            out_shape = matmul.out_dims[0].shape
            if len(out_shape) != 2:
                raise ValueError(
                    f'strange outputs shape of {out_shape} for matmul {matmul.name}'
                )
            if len(flat_value) != out_shape[0] and len(flat_value) != out_shape[1]:
                LOG.info(
                    "can't fuse %s into %s - value shape is not correct for bias",
                    const_node.name, matmul.name)
                break
            # in_dims are only refreshed by add_dimensions(), so a bias added
            # earlier in this loop is only visible in the graph edges
            bias_edge = next((edge for edge in G.in_edges(matmul.name)
                              if edge.to_idx == 2), None)
            bias_node = bias_edge.from_node if bias_edge is not None else None
            if bias_node is not None and not isinstance(bias_node, ConstantInputParameters):
                LOG.info("can't fuse %s into %s - bias is not a constant",
                         const_node.name, matmul.name)
                break
            out_node = seen_reshape[-1] if seen_reshape else matmul
            if isinstance(op_node, MatrixAddParameters):
                if bias_node is not None:
                    bias_value = bias_node.dqvalue
                    if (bias_value.size != flat_value.size
                            or len(G.out_edges(bias_node.name)) > 1):
                        LOG.info(
                            "can't fuse %s into %s - bias shape is not the same",
                            const_node.name, matmul.name)
                        break
                    LOG.info(
                        "folding additive bias from %s into existing bias on %s",
                        op_node.name, matmul.name)
                    # add element-wise and keep the bias' original layout
                    bias_node.value = (bias_value.flatten() + flat_value).reshape(
                        bias_value.shape)
                else:
                    if len(flat_value) == out_shape[1]:
                        # matmul needs to be transposed to fuse this.
                        # NOTE(review): for square outputs this always transposes,
                        # i.e. it assumes the constant broadcasts along the last
                        # axis - confirm for constants shaped (M, 1).
                        _, trans_node = reverse_matmul(G, matmul)
                        if not seen_reshape:
                            out_node = trans_node
                        has_transposed = True
                    bias_node = ConstantInputParameters(
                        G.unique_name(f'{matmul.name}_bias'),
                        value=flat_value,
                        dims=Dim.unnamed(flat_value.shape))
                    G.add_edge(NNEdge(from_node=bias_node, to_node=matmul, to_idx=2))
                    LOG.info(
                        "folding additive bias from %s into new bias on %s",
                        op_node.name, matmul.name)
            else:
                params_in = G.indexed_in_edges(matmul.name)
                # only the two matrix operands can absorb the scale
                if isinstance(params_in[1].from_node, ConstantInputParameters):
                    mult_const_node = params_in[1].from_node
                elif isinstance(params_in[0].from_node, ConstantInputParameters):
                    mult_const_node = params_in[0].from_node
                else:
                    break
                # never rescale constants that other nodes also read
                if any(len(G.out_edges(node.name)) > 1
                       for node in (mult_const_node, bias_node) if node is not None):
                    LOG.info("can't fuse %s into %s - constant is shared",
                             const_node.name, matmul.name)
                    break
                # NOTE(review): relies on numpy broadcasting of the scale against
                # the chosen operand matching the output axis - confirm.
                mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                if bias_node is not None:
                    bias_node.value = bias_node.dqvalue * const_node.dqvalue
                LOG.info(
                    "folding multaplicative bias from %s into new bias on %s",
                    op_node.name, matmul.name)
            has_modified_graph = True
            out_edges = G.out_edges(op_node.name)
            G.remove(op_node)
            if remove_constant:
                G.remove(const_node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=out_node, to_node=edge.to_node, to_idx=edge.to_idx))

    if has_modified_graph:
        G.add_dimensions()
        if G.quantization:
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()
            RemoveUnnecessaryQuantizeOperators().match(G)
        if has_transposed:
            G.adjust_order()

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse Pad -> Add [-> Activation] chains into PaddedAddFusionParameters.

    For every PadParameters node, ``self.get_node_list`` proposes the chain
    to fuse; chains shorter than two nodes are skipped. The pad may feed
    either input of the add. The fusion keeps that position: the pad's
    producer feeds fusion input ``padded_input_idx``, and the add's other
    operand keeps its own input index.

    Args:
        G: graph modified in place.
        set_identity: call ``self.set_identity(G)`` before returning.

    Returns:
        True if at least one fusion was made.
    """
    has_modified_graph = False
    for pad_node in [
            params for params in G.nodes() if isinstance(params, PadParameters)
    ]:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # add input that receives the padded tensor
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad,
                   to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # Each entry maps a fusion input to (inner node, inner input index).
        # The pad is always entered at its input 0 (assumes the pad has a
        # single data input, as in the padded_input_idx == 0 case).
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 0)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            # TODO - stats
            if qrecs:
                prec = QRec.copy_ktype(qrecs[1],
                                       in_qs=qrecs[1].in_qs,
                                       out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # Pair each external input edge with the fusion input it must feed.
        # The pad's producer edge has to_idx 0 even when the pad feeds add
        # input 1, so its target index must be set explicitly.
        fused_in_edges = [(edge, padded_input_idx)
                          for edge in G.in_edges(node_list.pad.name)]
        fused_in_edges.extend(
            (edge, edge.to_idx) for edge in G.in_edges(node_list.add.name)
            if edge.to_idx != padded_input_idx)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge, fusion_idx in fused_in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       pnode,
                       from_idx=edge.from_idx,
                       to_idx=fusion_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode,
                       edge.to_node,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each fully connected node with the nodes that follow it.

    ``kwargs['group_identity']`` selects which activations may take part:
    VALID_ACTIVATIONS_POW2 for 'pow2_match_group', VALID_ACTIVATIONS_SQ8
    otherwise. For every FcParameters node, ``self.get_node_list`` proposes
    the chain to fuse; chains shorter than two nodes are ignored. The chain
    is replaced by a LinearFusionParameters node that takes the linear
    node's three inputs and the last chain node's output. When the graph is
    quantized and records exist, they are moved into the fusion, and a record
    of the first record's kind is created from the first in_qs and the last
    out_qs.

    Args:
        G: graph modified in place.
        set_identity: call ``self.set_identity(G)`` before returning.
        **kwargs: may carry 'group_identity'.

    Returns:
        True if at least one fusion was made.
    """
    pow2_group = kwargs.get('group_identity') == 'pow2_match_group'
    valid_activations = VALID_ACTIVATIONS_POW2 if pow2_group else VALID_ACTIVATIONS_SQ8

    has_modified_graph = False
    fc_nodes = [candidate for candidate in G.nodes()
                if isinstance(candidate, FcParameters)]
    for fc_node in fc_nodes:
        node_list = self.get_node_list(G, fc_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue

        chain = node_list.order
        LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
        has_modified_graph = True

        subgraph = GraphView()
        for upstream, downstream in zip(chain, chain[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        tail = chain[-1]

        linear = node_list.linear
        pnode = LinearFusionParameters(
            linear.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=[[(linear, in_idx)] for in_idx in range(3)],
            output_mapping=[(tail, 0)])

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                head_qrec = qrecs[0]
                fused_qrec = QRec.copy_ktype(
                    head_qrec, in_qs=head_qrec.in_qs, out_qs=qrecs[-1].out_qs)
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        # capture the boundary edges before removing the fused nodes
        in_edges = G.in_edges(linear.name)
        out_edges = G.out_edges(tail.name)
        for member in chain:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph