def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace each node matched by ``MatScaleNodeMatch`` with a fusion node.

    The matcher pattern is a single node, so every fragment it yields holds
    exactly one graph node. For each one, the node is removed from ``G``. A
    ``MatScaleFusionParameters`` node named ``"<node name>_fusion"``, wrapping
    that fragment, is added in its place. The edges selected by the match's
    ``'inputs'`` list are re-targeted at the fusion node, with input indices
    renumbered 0, 1, ... in list order. The node's former output edges are
    re-sourced from the fusion node.

    Args:
        G: graph rewritten in place.
        set_identity: if true, ``self.set_identity(G)`` is called once after
            all rewrites.
        **kwargs: accepted but not used by this method.

    Returns:
        True if at least one node was replaced, False otherwise.
    """
    matcher = GraphMatcher(
        match_function=lambda state, frag: (frag, state['match']))
    matcher.add_node(MatScaleNodeMatch())

    changed = False
    # NOTE(review): G is mutated inside this loop. That is only safe if
    # match_graph returns a materialized list (other callers take len() of
    # its result). Confirm it is not a lazy generator.
    for subgraph, info in matcher.match_graph(G):
        # Collect every edge that must be re-wired *before* the node is
        # removed from G.
        inputs_to_rewire = [
            G.indexed_in_edges(in_node.name)[in_idx]
            for in_node, in_idx in info['inputs']
        ]
        old_node = list(subgraph.nodes())[0]
        outputs_to_rewire = G.out_edges(old_node.name)

        changed = True
        G.remove(old_node)

        fusion = MatScaleFusionParameters(
            "{}_fusion".format(old_node.name),
            fusion_type=info['type'],
            subgraph=subgraph,
            # Two fusion inputs are hard-wired here. Presumably they map to
            # inputs 0 and 1 of the matched node inside the subgraph (confirm).
            input_mapping=[[(old_node, 0)], [(old_node, 1)]])
        G.add_node(fusion)

        # Re-attach the captured input edges, renumbered in 'inputs' order.
        for new_idx, edge in enumerate(inputs_to_rewire):
            edge.to_node = fusion
            edge.to_idx = new_idx
            G.add_edge(edge)
        # Former consumers of the old node now read from the fusion node.
        for edge in outputs_to_rewire:
            edge.from_node = fusion
            G.add_edge(edge)

    if set_identity:
        self.set_identity(G)
    return changed
def test_match4alt():
    """Match test2 -> test4 and test3 -> test4 via an inputs-group edge factory.

    Both pattern edges come from a single ``MatchEdgeInputsGroupFactory``.
    The match must contain exactly the three named nodes and their two
    edges.
    """
    G = Graph()
    for node_name in ("test1", "test2", "test3", "test4", "test5"):
        G.add_node(Node(node_name))

    # (source, target, extra Edge keyword arguments)
    edge_specs = [
        ('test1', 'test2', {}),
        ('test2', 'test4', {}),
        ('test3', 'test4', {'to_idx': 1}),
        ('test4', 'test5', {}),
    ]
    for src, dst, extra in edge_specs:
        G.add_edge(Edge(src, dst, **extra))

    fragment = GraphMatcher()
    left, right, sink = (MatchNodeByName(name)
                         for name in ("test2", "test3", "test4"))
    group = MatchEdgeInputsGroupFactory()
    for source in (left, right):
        fragment.add_edge(group.get_edge(from_node=source, to_node=sink))

    matches = fragment.match_graph(G)
    assert len(matches) == 1
    only = matches[0]
    assert only.num_nodes() == 3
    assert only.num_edges() == 2
def get_matcher(self):
    """Build a matcher for one edge between two matrix-multiply node matches.

    Both endpoints are created with
    ``MatScalePairMatchFactory.MatchMatMulTwice(MatrixMulParameters)``.
    They are joined by a ``MatScaleEdgeMatch`` constructed with ``self``.
    Each match is reported as a ``(fragment, state['match'])`` tuple.

    Returns:
        The configured ``GraphMatcher``.
    """
    def _emit(state, frag):
        # Pair the matched fragment with the match info recorded in state.
        return frag, state['match']

    matcher = GraphMatcher(match_function=_emit)
    upstream = MatScalePairMatchFactory.MatchMatMulTwice(MatrixMulParameters)
    downstream = MatScalePairMatchFactory.MatchMatMulTwice(MatrixMulParameters)
    link = MatScalePairMatchFactory.MatScaleEdgeMatch(
        self, from_node=upstream, to_node=downstream)
    matcher.add_edge(link)
    return matcher
def test_mn3_match2(mn3_graph, caplog):
    """``MatScaleNodeMatch`` finds 16 matches in the imported mn3_graph fixture."""
    importer = TfliteImporter()
    G = importer.create_graph(mn3_graph, {'load_tensors': True})
    G.add_dimensions()
    G.adjust_order()
    # Debug logging is turned on only for the matching phase.
    caplog.set_level(logging.DEBUG)

    def _report(state, frag):
        return frag, state['match']

    fragment = GraphMatcher(match_function=_report)
    fragment.add_node(MatScaleNodeMatch())
    found = fragment.match_graph(G)
    assert len(found) == 16
def test_match1():
    """A single test1 -> test2 edge pattern matches exactly one two-node subgraph."""
    G = Graph()
    G.add_node(Node("test1"))
    G.add_node(Node('test2'))
    G.add_edge(Edge('test1', 'test2'))
    # The target is passed as a Node object rather than a name.
    # Presumably this adds test3 to the graph implicitly (confirm).
    G.add_edge(Edge('test2', Node('test3')))
    G.add_node(Node('test4'))
    G.add_edge(Edge('test3', 'test4'))
    assert not G.verify_edges()

    fragment = GraphMatcher()
    head = MatchNodeByName("test1")
    tail = MatchNodeByName('test2')
    fragment.add_edge(MatchEdgeByIdx(from_node=head, to_node=tail))

    found = fragment.match_graph(G)
    assert len(found) == 1
    subgraph = found[0]
    assert len(subgraph) == 2
    assert subgraph.num_edges() == 1
def test_match_fail(caplog):
    """Single-use suffix matchers pick out only the test2a -> test3b edge.

    Chain: test1a -> test2a -> test3b -> test4a. The pattern asks for an
    edge from a node ending in 'a' to a node ending in 'b'. Only
    test2a -> test3b satisfies it.
    """
    caplog.set_level(logging.DEBUG)

    class MatchSuffNode(NodeMatch):
        """Accept a node whose name ends with a suffix, at most once until reset."""

        def __init__(self, suff):
            self._suff = suff
            self._has_matched = False

        def match(self, G, node, state):
            # Refuse any further node once one has been accepted; the suffix
            # test is only evaluated while still unmatched.
            accepted = not self._has_matched and node.name.endswith(self._suff)
            if accepted:
                self._has_matched = True
            return accepted

        def commit_match(self, G, node, state):
            pass

        def reset_match(self, G, state, node=None, init=False):
            self._has_matched = False

    G = Graph()
    G.add_node(Node("test1a"))
    for src, dst in (('test1a', 'test2a'),
                     ('test2a', 'test3b'),
                     ('test3b', 'test4a')):
        G.add_edge(Edge(src, Node(dst)))

    fragment = GraphMatcher()
    fragment.add_edge(
        MatchEdgeByIdx(from_node=MatchSuffNode('a'),
                       to_node=MatchSuffNode('b')))

    found = fragment.match_graph(G)
    assert len(found) == 1
    subgraph = found[0]
    assert subgraph.num_nodes() == 2
    assert subgraph.num_edges() == 1
    assert {n.name for n in subgraph.nodes()} == {'test2a', 'test3b'}
def test_match3(caplog):
    """A fan-out pattern test1 -> {test2, test3} matches once with 3 nodes, 2 edges."""
    caplog.set_level(logging.DEBUG)
    G = Graph()
    G.add_node(Node("test1"))
    G.add_node(Node('test2'))
    G.add_edge(Edge('test1', 'test2'))
    G.add_edge(Edge('test1', Node('test3')))
    G.add_node(Node('test4'))
    # This edge lies outside the pattern below. The assertions check that
    # the match excludes test4 and this edge (only 3 nodes and 2 edges).
    G.add_edge(Edge('test2', 'test4'))

    fragment = GraphMatcher()
    root = MatchNodeByName('test1')
    for leaf_name in ('test2', 'test3'):
        fragment.add_edge(
            MatchEdgeByIdx(from_node=root, to_node=MatchNodeByName(leaf_name)))

    found = fragment.match_graph(G)
    assert len(found) == 1
    subgraph = found[0]
    assert subgraph.num_nodes() == 3
    assert subgraph.num_edges() == 2
def test_match4():
    """Match a fan-in (test2 -> test4 input 0, test3 -> test4 input 1) by edge index."""
    G = Graph()
    for node_name in ("test1", "test2", "test3", "test4", "test5"):
        G.add_node(Node(node_name))

    # (source, target, extra Edge keyword arguments)
    edge_specs = [
        ('test1', 'test2', {}),
        ('test2', 'test4', {}),
        ('test3', 'test4', {'to_idx': 1}),
        ('test4', 'test5', {}),
    ]
    for src, dst, extra in edge_specs:
        G.add_edge(Edge(src, dst, **extra))

    fragment = GraphMatcher()
    first_in, second_in, sink = (MatchNodeByName(name)
                                 for name in ("test2", "test3", "test4"))
    for source, extra in ((first_in, {}), (second_in, {'to_idx': 1})):
        fragment.add_edge(
            MatchEdgeByIdx(from_node=source, to_node=sink, **extra))

    found = fragment.match_graph(G)
    assert len(found) == 1
    subgraph = found[0]
    assert subgraph.num_nodes() == 3
    assert subgraph.num_edges() == 2