def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace every single-node MatScale match in ``G`` with a fusion node.

    Each node matched by ``MatScaleNodeMatch`` is removed from ``G`` and
    replaced by a ``MatScaleFusionParameters`` node wrapping the matched
    fragment. The edges that fed the matched node and the edges it drove are
    re-created on the fusion node.

    Args:
        G: graph to rewrite in place.
        set_identity: when true, ``self.set_identity(G)`` is called after
            matching (whether or not anything was fused).
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        True if at least one node was fused, otherwise False.
    """
    fragment = GraphMatcher(
        match_function=lambda state, frag: (frag, state['match']))
    fragment.add_node(MatScaleNodeMatch())
    has_modified_graph = False
    for frag, match in fragment.match_graph(G):
        # Materialize once: the list is used both to collect the input edges
        # and to build the fusion input mapping, so they stay in lockstep.
        inputs = list(match['inputs'])
        # Capture incoming/outgoing edges before the node leaves the graph.
        match_edges = [G.indexed_in_edges(node.name)[idx]
                       for node, idx in inputs]
        matched_node = list(frag.nodes())[0]
        out_edges = G.out_edges(matched_node.name)
        has_modified_graph = True
        G.remove(matched_node)
        # Fusion input i is fed by match_edges[i], which originally went to
        # input ``idx`` of the matched node, so derive the mapping from the
        # same list rather than assuming inputs (0, 1) in that order.
        fnode = MatScaleFusionParameters(
            "{}_fusion".format(matched_node.name),
            fusion_type=match['type'],
            subgraph=frag,
            input_mapping=[[(matched_node, idx)] for _, idx in inputs])
        G.add_node(fnode)
        # Attach copies of the captured edges instead of mutating the
        # original edge objects in place.
        for fused_idx, edge in enumerate(match_edges):
            G.add_edge(edge.clone(to_node=fnode, to_idx=fused_idx))
        for edge in out_edges:
            G.add_edge(edge.clone(from_node=fnode))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each matched pair of MatScale nodes into one ``vec_scalar`` node.

    For every fragment produced by ``MatScalePairMatchFactory``'s matcher,
    all fragment nodes are removed from ``G`` and replaced by a single
    ``MatScaleFusionParameters`` node. The external input edges and the
    output edges of the fragment's last node are re-created (as clones) on
    the fusion node, and the producers' ``out_dims_hint`` entries are copied
    into the fusion node's ``in_dims_hint``.

    Args:
        G: graph to rewrite in place.
        set_identity: when true, ``self.set_identity(G)`` is called after
            matching (whether or not anything was fused).
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        True if at least one pair was fused, otherwise False.
    """
    fac = MatScalePairMatchFactory()
    has_modified_graph = False
    for frag, match in fac.get_matcher().match_graph(G):
        # Capture external input edges and the fragment's output edges before
        # the fragment nodes are removed from the graph.
        match_edges = [G.indexed_in_edges(node.name)[idx]
                       for node, idx in match]
        first_node = frag.inputs()[0]
        last_node = frag.outputs()[0]
        out_edges = G.out_edges(last_node.name)
        for node in frag.nodes():
            G.remove(node)
        input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
            match_edges)
        fnode = MatScaleFusionParameters(
            "{}_{}_fusion".format(first_node.name, last_node.name),
            fusion_type="vec_scalar",
            subgraph=frag,
            input_mapping=MatScaleFusionParameters.convert_input_mapping(
                input_mapping))
        has_modified_graph = True
        G.add_node(fnode)
        # One hint slot per fusion input rather than a fixed three.
        fnode.in_dims_hint = [None] * len(match_edges)
        for edge in match_edges:
            new_edge = edge.clone(
                to_node=fnode,
                to_idx=list(input_mapping[edge.to_node].keys())[0])
            producer_hints = new_edge.from_node.out_dims_hint
            if producer_hints:
                # Store the hint at the fusion input the edge actually lands
                # on, which comes from input_mapping, not the loop position.
                fnode.in_dims_hint[new_edge.to_idx] = producer_hints[
                    edge.from_idx]
            G.add_edge(new_edge)
        for edge in out_edges:
            G.add_edge(edge.clone(from_node=fnode))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph