def match(self, G: GraphView, set_identity: bool = True):
    """Fuse every node set found by ``self.find_sets(G)`` into one expression node.

    For each node set, the set is copied into a standalone ``GraphView``
    fragment and replaced in ``G`` by a single ``ExpressionFusionParameters``
    node named ``expr_<n>``. ``<n>`` is ``self._expr_num``, which is
    incremented once per fused set.

    Args:
        G: Graph that is modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called after all
            sets have been processed.

    Returns:
        True if at least one node set was fused, otherwise False.
    """
    has_modified_graph = False
    for node_set in self.find_sets(G):
        has_modified_graph = True
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        # Add the nodes explicitly. A single-node set has no internal edges,
        # so building the fragment from edges alone would leave it empty.
        for node in node_set:
            frag.add_node(node)
        for edge in internal_edges:
            frag.add_edge(edge)
        # in_edges is keyed by (from_node, from_idx). Each value groups the
        # edges fed by that output; for each group, record the internal
        # (node, idx) targets.
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        LOG.info('matched expression - creating expression %s', self._expr_num)
        expr = ExpressionFusionParameters(f"expr_{self._expr_num}",
                                          subgraph=frag,
                                          input_mapping=in_mapping,
                                          output_mapping=out_mapping,
                                          in_dims=in_dims,
                                          out_dims=out_dims,
                                          constant_inputs=constant_inputs)
        # Both mappings follow the same dict iteration order as
        # in_mapping / out_mapping, so input/output i lines up on both sides.
        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping)
        self._expr_num += 1
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Group fusable nodes into connected sets and fuse each set into an expression.

    Candidate nodes are those of a ``FUSE_NODES`` type, plus
    ``ConstantInputParameters`` whose first output has size 1. These are
    grouped by ``group_nodes``. Sets made only of constants are discarded.
    Every remaining set is replaced in ``G`` by an ``ExpressionFusionParameters``
    node named ``expr_<n>``. ``<n>`` is ``self._expr_num``, which is advanced
    once per fused set so that names stay distinct.

    Args:
        G: Graph that is modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one set was fused, otherwise False.
    """
    has_modified_graph = False
    # collect connected node sets
    node_sets = group_nodes(G, [
        node for node in G.nodes()
        if isinstance(node, FUSE_NODES) or (
            isinstance(node, ConstantInputParameters)
            and node.out_dims[0].size() == 1)
    ])
    # remove sets that are only ConstantInputs
    node_sets = [
        node_set for node_set in node_sets
        if not all(
            isinstance(node, ConstantInputParameters) for node in node_set)
    ]
    for node_set in node_sets:
        has_modified_graph = True
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        # Add the nodes explicitly. A single-node set has no internal edges,
        # so building the fragment from edges alone would leave it empty.
        for node in node_set:
            frag.add_node(node)
        for edge in internal_edges:
            frag.add_edge(edge)
        # Fix ONE input order (by source output index) and derive every
        # per-input list from it. Previously only edge_in_mapping was sorted,
        # so expression input i could be wired to a different external edge
        # than the internal consumers listed in input_mapping[i].
        ordered_inputs = sorted(in_edges.items(), key=lambda item: item[0][1])
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for _, edge_group in ordered_inputs]
        edge_in_mapping = [from_key for from_key, _ in ordered_inputs]
        constant_inputs = [
            isinstance(from_key[0], ConstantInputParameters)
            for from_key, _ in ordered_inputs
        ]
        out_mapping = list(out_edges.keys())
        expr = ExpressionFusionParameters("expr_%s" % self._expr_num,
                                          subgraph=frag,
                                          input_mapping=in_mapping,
                                          output_mapping=out_mapping,
                                          constant_inputs=constant_inputs)
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=edge_in_mapping,
            edge_out_mapping=[[(edge.to_node, edge.to_idx) for edge in edge_set]
                              for edge_set in out_edges.values()])
        # Advance the counter; without this, every fused node was named
        # with the same "expr_<n>".
        self._expr_num += 1
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Apply ``replace_function`` to matched subgraphs until nothing changes.

    Each pass walks ``self.match_function(G)``. For a candidate subgraph,
    ``self.replace_function(G, subgraph)`` must return a triple
    ``(replacement, edge_in_mapping, edge_out_mapping)``:

    * ``replacement is None`` -> the subgraph is removed from ``G``.
    * ``replacement`` is a ``Node`` -> the subgraph is replaced by it, using
      the boundary edges captured before ``replace_function`` ran.
    * anything else -> ``TypeError``.

    A ``DontReplaceError`` skips the candidate. After any successful change,
    the search restarts from a fresh ``match_function`` pass. The loop ends
    when a full pass makes no change.

    Args:
        G: Graph that is modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called at the end.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True if any subgraph was removed or replaced.
    """
    modified = False
    keep_going = True
    while keep_going:
        keep_going = False
        for candidate in self.match_function(G):
            # TODO - Snapshot the boundary edges now, because replace_function
            # may modify the subgraph before we get to use them.
            boundary_in = [
                edge
                for node in candidate.inputs()
                for edge in G.in_edges(node.name)
            ]
            boundary_out = [
                edge
                for node in candidate.outputs()
                for edge in G.out_edges(node.name)
            ]
            try:
                new_node, in_map, out_map = self.replace_function(G, candidate)
                if new_node is None:
                    G.remove_fragment(candidate)
                elif isinstance(new_node, Node):
                    G.replace_fragment(candidate,
                                       new_node,
                                       frag_in_edges=boundary_in,
                                       frag_out_edges=boundary_out,
                                       edge_in_mapping=in_map,
                                       edge_out_mapping=out_map)
                else:
                    raise TypeError(
                        "unexcepted return value from replace_function")
                modified = True
            except DontReplaceError:
                continue
            # The graph changed, so match_function's view is stale: restart.
            keep_going = True
            break
    if set_identity:
        self.set_identity(G)
    return modified
def match(self, G: GraphView, set_identity: bool = True):
    """Replace matched subgraphs one at a time until no match remains.

    For the first subgraph yielded by ``self.match_function(G)``, the result
    of ``self.replace_function(G, subgraph)`` decides what happens:

    * a falsy result -> the subgraph is removed from ``G``;
    * a ``Node`` -> the subgraph is replaced by that node;
    * anything else -> ``TypeError``.

    Matching then restarts from scratch. The loop stops once
    ``match_function`` yields nothing. Returns ``None``.

    Args:
        G: Graph that is modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called at the end.
    """
    while True:
        for found in self.match_function(G):
            result = self.replace_function(G, found)
            if not result:
                G.remove_fragment(found)
            elif isinstance(result, Node):
                G.replace_fragment(found, result)
            else:
                raise TypeError(
                    "unexcepted return value from replace_function")
            # Graph changed: abandon this pass and start matching again.
            break
        else:
            # A full pass yielded no subgraph, so we are done.
            break
    if set_identity:
        self.set_identity(G)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Fuse each node set found by ``self.find_sets(G)`` into an expression node.

    For every node set, a ``GraphView`` fragment holding the set's nodes and
    internal edges is built. That fragment is replaced in ``G`` by an
    ``ExpressionFusionParameters`` node whose name comes from
    ``G.unique_name(f"expr_{self._expr_num}")``.

    When ``G.quantization`` is set, the new node also receives a ``QRec``.
    Its input/output quantization is taken from the fused nodes' existing
    records, and min/max stats are seeded from those records. Directly
    following ``QuantizeParameters`` nodes are removed, and all new
    expressions are then re-quantized in one ``UnifiedQuantizer`` run.

    Args:
        G: Graph that is modified in place (graph structure and, when
            present, ``G.quantization``).
        set_identity: When true, ``self.set_identity(G)`` is called at the end.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True if at least one node set was fused, otherwise False.
    """
    has_modified_graph = False
    # Expressions created while the graph is quantized; re-quantized together
    # after the loop.
    to_quantize = []
    node_sets = self.find_sets(G)
    for node_set in node_sets:
        # Install a fresh SymbolStats as the global default control before the
        # expression is built; its .stats dict is read back via
        # Symbol.CURRENT_CONTROL below. Must stay before ExpressionFusionParameters.
        Symbol.set_default_control(SymbolStats())
        has_modified_graph = True
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        # Add nodes explicitly so a set without internal edges is still captured.
        for node in node_set:
            frag.add_node(node)
        for edge in internal_edges:
            frag.add_edge(edge)
        # in_edges/out_edges are keyed by (from_node, from_idx); each value is a
        # group of edges. in_mapping[i] lists the internal (node, idx) targets
        # of input group i.
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        in_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in in_edges
        ]
        out_dims = [
            from_node.out_dims[from_idx] for from_node, from_idx in out_edges
        ]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            node_edge_idx[0] for node_edge_idx in in_edges
            if isinstance(node_edge_idx[0], ConstantInputParameters)
        ]
        # NOTE(review): the joined string is built eagerly even when DEBUG is off.
        LOG.debug(
            "inputs coming from: %s",
            ",".join(f"{from_node.__repr__()}:{from_idx}"
                     for from_node, from_idx in in_edges))
        LOG.info("fusing nodes: %s into expr_%s",
                 ",".join(node.__repr__() for node in node_set),
                 self._expr_num)
        expr = ExpressionFusionParameters(
            G.unique_name(f"expr_{self._expr_num}"),
            subgraph=frag,
            qrecs=G.quantization,
            input_mapping=in_mapping,
            output_mapping=out_mapping,
            in_dims=in_dims,
            out_dims=out_dims,
            constant_inputs=constant_inputs)
        # Same dict iteration order as in_mapping/out_mapping, so expression
        # port i lines up with external edge group i.
        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_set]
                            for edge_set in out_edges.values()]
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping,
            edge_class=NNEdge)
        if G.quantization:
            qrecs = G.quantization
            # For each expression input, reuse the in_q of the first internal
            # consumer of that input.
            in_qs = [
                qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                for in_map in in_mapping
            ]
            # For each expression output, reuse the out_q of the fused node
            # that produced it.
            out_qs = [
                qrecs[NodeId(node)].out_qs[idx]
                for node, idx in out_mapping
            ]
            # Seed min/max stats of the expression's input/output symbols from
            # the reused quantization records.
            stats = Symbol.CURRENT_CONTROL.stats
            func_col = expr.func_col
            for idx, qtype in enumerate(in_qs):
                symbol = func_col.variables[func_col.input_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            for idx, qtype in enumerate(out_qs):
                symbol = func_col.variables[func_col.output_names[idx]]
                stats[symbol.name] = {
                    'min': qtype.min_val,
                    'max': qtype.max_val
                }
            G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                out_qs=out_qs,
                                                expression=stats,
                                                ktype='scaled')
            # delete any quantize parameters on outputs to allow the quantizer
            # to fuse them into the expression
            out_edges = G.out_edges(expr.name)
            for edge in out_edges:
                if isinstance(edge.to_node, QuantizeParameters):
                    G.remove_and_reconnect(edge.to_node)
                    if NodeId(edge.to_node) in G.quantization:
                        del G.quantization[NodeId(edge.to_node)]
            to_quantize.append(expr)
        self._expr_num += 1
    if to_quantize:
        # Re-run quantization starting from the new expression nodes, keeping
        # the graph's existing quantization as the starting point.
        quantizer = UnifiedQuantizer.from_quantized_graph(G)
        G.quantization = quantizer.quantize(G, start_nodes=to_quantize)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph