def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Drop a reshape that only wraps a chain of activation-like nodes.

    Matches ``Reshape -> (ActivationParameters | CustomMatcher())+ ->
    (Reshape | Fc)``. When the leading reshape's ``exp_red_pattern()`` is
    truthy, that pattern is pushed through the intermediate nodes with
    ``propagate_shape`` and the leading reshape is removed (graph node and
    quantization). A trailing reshape whose ``shape`` equals the leading
    reshape's ``old_shape`` is removed too; any other trailing reshape gets
    its ``old_shape`` set to the leading reshape's ``shape``.

    Note: unlike most matchers, ``set_identity`` is applied before matching.

    Returns:
        True if at least one leading reshape was removed.
    """
    if set_identity:
        self.set_identity(G)

    changed = False
    pattern = Sequence(
        ReshapeParameters,
        OneOrMoreOf(AnyOf(ActivationParameters, CustomMatcher())),
        AnyOf(ReshapeParameters, FcParameters),
    )
    for chain in pattern.find(G):
        head, tail = chain[0], chain[-1]
        exp_red = head.exp_red_pattern()
        if not exp_red:
            continue
        changed = True

        tail_is_reshape = isinstance(tail, ReshapeParameters)
        # Decide whether the tail cancels the head before any shape is touched.
        drop_tail = tail_is_reshape and head.old_shape.shape == tail.shape.shape

        # Only the nodes strictly between head and tail get the new shape.
        propagate_shape(G, head.shape, exp_red, chain[1:-1])

        LOG.info('removing unnecessary reshape %s', head.name)
        G.remove_and_reconnect(head, edge_class=NNEdge)
        self.remove_quantization(G, head)

        if drop_tail:
            LOG.info('removing unnecessary reshape %s', tail.name)
            G.remove_and_reconnect(tail, edge_class=NNEdge)
            self.remove_quantization(G, tail)
        elif tail_is_reshape:
            # The tail reshape is now fed directly with the head's shape.
            tail.old_shape = head.shape
    return changed
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove or simplify strided slices that select their whole input.

    Only slices whose ``slice_shape`` equals the shape of their first input
    are touched. If ``out_shape`` also equals ``slice_shape`` the slice is
    removed outright; otherwise it is replaced by a ``ReshapeParameters``
    going from ``slice_shape`` to ``out_shape``. A quantization record for
    the slice, if present, is deleted or moved to the new reshape.

    Returns:
        True if any strided slice was removed or replaced.
    """
    changed = False
    # Snapshot the candidates before the graph is edited.
    slices = list(G.nodes(node_classes=StridedSliceParameters))
    for node in slices:
        if node.slice_shape != tuple(node.in_dims[0].shape):
            continue  # selects a genuine sub-region; keep it
        changed = True
        old_id = NodeId(node)

        if node.slice_shape == node.out_shape:
            LOG.info(
                f'removing strided slice {node.name} that does nothing')
            G.remove_and_reconnect(node, edge_class=NNEdge)
            if G.quantization and old_id in G.quantization:
                del G.quantization[old_id]
            continue

        # Same data, different shape: a reshape is enough.
        reshape = ReshapeParameters(
            G.unique_name(f'{node.name}_reshape'),
            old_shape=node.slice_shape,
            shape=node.out_shape)
        LOG.info(
            f'replacing strided slice {node.name} with reshape {reshape.name}'
        )
        G.replace_node(node, reshape)
        if G.quantization and old_id in G.quantization:
            G.quantization[NodeId(reshape)] = G.quantization[old_id]
            del G.quantization[old_id]

    if set_identity:
        self.set_identity(G)
    return changed
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove QuantizeParameters nodes by pushing their types into neighbours.

    For each quantize node:
      * float -> float: cannot be removed directly; a warning is logged.
      * float -> other: ``self.propagate_up`` is asked to apply ``to_qtype``
        upstream.
      * other -> anything: ``self.propagate_down`` is asked to apply
        ``from_qtype`` downstream.
    If the propagation reports success, the quantize node is removed and
    reconnected, and its quantization record (if any) is deleted.

    Returns:
        True if at least one quantize node was removed.
    """
    float_types = (np.floating, bfloat16)
    modified_graph = False
    for node in G.nodes(node_classes=QuantizeParameters):
        if issubclass(node.from_qtype.dtype, float_types):
            if issubclass(node.to_qtype.dtype, float_types):
                LOG.warning(
                    'node %s quantizes from floating type to floating type and cannot directly be removed',
                    node.name)
                continue
            removable = self.propagate_up(G, node, node.to_qtype)
        else:
            removable = self.propagate_down(G, node, node.from_qtype)

        if not removable:
            LOG.warning('unable to remove quantize node %s', node.name)
            continue

        modified_graph = True
        G.remove_and_reconnect(node, edge_class=NNEdge)
        nid = NodeId(node)
        # Guarded like the other matchers: a missing qrec must not raise.
        if G.quantization and nid in G.quantization:
            del G.quantization[nid]

    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Iteratively remove reshapes that do nothing or can be folded.

    Each pass:
      1. removes every reshape with no transpose whose ``shape`` equals its
         ``old_shape`` (its quantization record is deleted too);
      2. asks ``self.validate_reshape`` about each remaining reshape. On the
         first truthy result ``(reshape, candidates, out_shape)``, every
         candidate is removed, its consumers are rewired to ``reshape``,
         and ``reshape.shape`` becomes ``Dim.unnamed(out_shape)``.
    Passes repeat until one makes no change.

    Returns:
        True if any pass modified the graph. (Previously this returned the
        loop-control flag, which is always False once the loop exits.)
    """
    has_modified_graph = False
    modified_graph = True
    while modified_graph:
        modified_graph = False

        for reshape in G.nodes(node_classes=(ReshapeParameters, )):
            if not reshape.has_transpose and reshape.shape.shape == reshape.old_shape.shape:
                modified_graph = True
                LOG.info('removing reshape that does nothing %s', reshape.name)
                G.remove_and_reconnect(reshape, edge_class=NNEdge)
                nid = NodeId(reshape)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]

        for reshape in G.nodes(node_classes=(ReshapeParameters, )):
            res = self.validate_reshape(G, reshape)
            if not res:
                continue
            LOG.info('unnecessary reshape found after %s', reshape.name)
            modified_graph = True
            (reshape, candidates, out_shape) = res
            for candidate in candidates:
                LOG.info(
                    'removing unnecessary reshape or transpose %s', candidate.name)
                # Capture the consumers before the candidate disappears.
                edges = G.out_edges(candidate.name)
                G.remove(candidate)
                nid = NodeId(candidate)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
                for edge in edges:
                    G.add_edge(
                        NNEdge(from_node=reshape, to_node=edge.to_node,
                               to_idx=edge.to_idx))
            reshape.shape = Dim.unnamed(out_shape)
            # The node list is stale after this edit; start a new pass.
            break

        if modified_graph:
            has_modified_graph = True

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove CopyParameters nodes judged redundant.

    A copy with exactly one consumer is kept only if BOTH searches succeed:
    ``search_down`` from its output edge reaches one of the listed node
    types, and ``search_up`` from its first input edge reaches one of them.
    Both searches may pass through reshapes, NoOPs and transposes whose
    ``does_nothing`` is true. Every other single-consumer copy is removed and
    reconnected, and its quantization record (if any) is deleted.
    NOTE(review): presumably the copy is needed to separate two buffers that
    cannot be aliased; confirm against search_up/search_down.

    Returns:
        True if at least one copy was removed.
    """
    def passes_noop_transpose(G, node):
        # A transpose that does nothing is as transparent as a reshape here.
        return isinstance(node, TransposeParameters) and node.does_nothing

    passable = (ReshapeParameters, NoOPParameters)
    has_modified_graph = False
    nodes_to_remove = []
    for node in G.nodes(node_classes=CopyParameters):
        out_edges = G.out_edges(node)
        # Only single-consumer copies are candidates; this also protects
        # out_edges[0] when the copy has no consumer at all.
        if len(out_edges) != 1:
            continue
        bounded_below = search_down(
            G, out_edges[0],
            (OutputParameters, InputParameters, ConstantInputParameters,
             SplitParameters, ConcatParameters),
            can_pass=passable,
            can_pass_fn=passes_noop_transpose,
            follow_multi=True)
        # search_up is only evaluated when search_down succeeded.
        if bounded_below and search_up(
                G, G.in_edges(node)[0],
                (InputParameters, OutputParameters, ConstantInputParameters,
                 SplitParameters, ConcatParameters),
                can_pass=passable,
                can_pass_fn=passes_noop_transpose,
                follow_multi=True):
            continue
        nodes_to_remove.append(node)

    for node in nodes_to_remove:
        LOG.info("remove redundant copy %s", node.name)
        has_modified_graph = True
        G.remove_and_reconnect(node, edge_class=NNEdge)
        if G.quantization:
            nid = NodeId(node)
            if nid in G.quantization:
                del G.quantization[nid]

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove identity quantize nodes and merge chains of quantize nodes.

    Each QuantizeParameters node is visited in ``G.nodes`` order:
      * if the graph is quantized and the node has no qrec, it is skipped;
      * if its qrec's input and output qtypes compare equal, the node is
        removed together with its qrec;
      * otherwise, while ``self.get_single_quantize_edge`` returns a
        following quantize node, that node is removed and its ``to_qtype``
        (and qrec ``out_qs``) are folded into the current node.
    Nodes folded away are skipped when they come up in the work list. Before
    this change they were still processed after being removed from ``G``.

    Returns:
        True if any quantize node was removed.
    """
    modified_graph = False
    quantize_nodes = G.nodes(node_classes=QuantizeParameters)
    # Names of nodes already merged into a predecessor (names are unique in G).
    merged_names = set()
    while quantize_nodes:
        node = quantize_nodes.pop(0)
        if node.name in merged_names:
            continue
        if G.quantization:
            qrec = G.quantization.get(NodeId(node))
            if not qrec:
                continue
            if deepcopy(qrec.in_qs[0]) == qrec.out_qs[0]:
                modified_graph = True
                LOG.info('removing quantize node %s from %s to %s',
                         node.name, qrec.in_qs[0], qrec.out_qs[0])
                G.remove_and_reconnect(node, edge_class=NNEdge)
                del G.quantization[NodeId(node)]
                continue

        next_node = self.get_single_quantize_edge(G, node)
        while next_node:
            LOG.info(
                'removing quantize node %s and modifying node %s to output %s',
                next_node.name, node.name, next_node.to_qtype)
            G.remove_and_reconnect(next_node, edge_class=NNEdge)
            merged_names.add(next_node.name)
            node.to_qtype = next_node.to_qtype
            modified_graph = True
            if G.quantization:
                this_rec = G.quantization[NodeId(node)]
                next_rec = G.quantization[NodeId(next_node)]
                this_rec.out_qs = next_rec.out_qs
                del G.quantization[NodeId(next_node)]
            next_node = self.get_single_quantize_edge(G, node)

    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove reshapes whose only consumer is a linear (Fc) layer.

    A reshape is removed (and its quantization record, if any) when it has
    exactly one out edge, that edge feeds ``FcParameters``, and that layer's
    ``batch_size`` is not greater than 1. After a removal, any reshape that
    fed the removed one is re-queued, because it may now feed the linear
    layer directly. This makes the result independent of the order in which
    ``set.pop`` yields candidates.

    Returns:
        True if any reshape was removed.
    """
    modified_graph = False
    candidates = set(G.nodes(node_classes=(ReshapeParameters, )))
    while candidates:
        node = candidates.pop()
        out_edges = G.out_edges(node.name)
        if (len(out_edges) != 1
                or not isinstance(out_edges[0].to_node, FcParameters)
                or out_edges[0].to_node.batch_size > 1):
            continue
        # Record the producers before the node is unlinked.
        producers = [edge.from_node for edge in G.in_edges(node.name)]
        LOG.info('removing unnecessary reshape before linear %s', node.name)
        G.remove_and_reconnect(node, edge_class=NNEdge)
        modified_graph = True
        nid = NodeId(node)
        if G.quantization and G.quantization.get(nid):
            del G.quantization[nid]
        candidates.update(
            producer for producer in producers
            if isinstance(producer, ReshapeParameters))

    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each node set from ``self.find_sets`` into one expression node.

    For every set:
      * the internal nodes/edges are copied into a fragment ``GraphView``;
      * an ``ExpressionFusionParameters`` named ``expr_<n>`` is built from
        it and replaces the fragment in ``G`` via ``replace_fragment``.
    If ``G`` is quantized:
      * the expression's input/output min/max are seeded into the current
        ``Symbol`` control stats from the qrecs of the original boundary
        nodes;
      * a scaled ``QRec`` is stored for the expression;
      * any QuantizeParameters directly after the expression are removed.
    All expressions built this way are then re-quantized with a
    ``UnifiedQuantizer`` started from them.

    Returns:
        True if at least one node set was fused.
    """
    has_modified_graph = False
    to_quantize = []
    node_sets = self.find_sets(G)
    for node_set in node_sets:
        # Fresh statistics collector for the symbols of this expression.
        Symbol.set_default_control(SymbolStats())
        has_modified_graph = True
        in_edges, out_edges, internal_edges = group_edges(G, node_set)

        frag = GraphView()
        for node in node_set:
            frag.add_node(node)
        for edge in internal_edges:
            frag.add_edge(edge)

        # in_edges / out_edges are keyed by (from_node, from_idx).
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        LOG.debug(
            "inputs coming from: %s",
            ",".join(f"{from_node.__repr__()}:{from_idx}"
                     for from_node, from_idx in in_edges))
        LOG.info("fusing nodes: %s into expr_%s",
                 ",".join(node.__repr__() for node in node_set),
                 self._expr_num)
        expr = ExpressionFusionParameters(
            G.unique_name(f"expr_{self._expr_num}"),
            subgraph=frag,
            qrecs=G.quantization,
            input_mapping=in_mapping,
            output_mapping=out_mapping,
            in_dims=in_dims,
            out_dims=out_dims,
            constant_inputs=constant_inputs)

        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping,
            edge_class=NNEdge)

        if G.quantization:
            qrecs = G.quantization
            # Input qtypes come from the first consumer of each fused input;
            # output qtypes from the producing node/index inside the set.
            in_qs = [
                qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                for in_map in in_mapping
            ]
            out_qs = [
                qrecs[NodeId(node)].out_qs[idx] for node, idx in out_mapping
            ]
            stats = Symbol.CURRENT_CONTROL.stats
            func_col = expr.func_col
            for idx, qtype in enumerate(in_qs):
                symbol = func_col.variables[func_col.input_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            for idx, qtype in enumerate(out_qs):
                symbol = func_col.variables[func_col.output_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                out_qs=out_qs,
                                                expression=stats,
                                                ktype='scaled')
            # delete any quantize parameters on outputs to allow the quantizer
            # to fuse them into the expression
            expr_out_edges = G.out_edges(expr.name)
            for edge in expr_out_edges:
                if isinstance(edge.to_node, QuantizeParameters):
                    # Reconnect with NNEdge like every other removal here.
                    G.remove_and_reconnect(edge.to_node, edge_class=NNEdge)
                    if NodeId(edge.to_node) in G.quantization:
                        del G.quantization[NodeId(edge.to_node)]
            to_quantize.append(expr)

        self._expr_num += 1

    if to_quantize:
        quantizer = UnifiedQuantizer.from_quantized_graph(G)
        G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

    if set_identity:
        self.set_identity(G)
    return has_modified_graph