Esempio n. 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse every node matched by ``MatScaleNodeMatch`` into a fusion node.

        For each match, this method:
        - removes the matched node from ``G``;
        - inserts a ``MatScaleFusionParameters`` node named
          ``"<node name>_fusion"`` that wraps the matched fragment;
        - re-attaches the matched input edges and the node's original output
          edges to the new fusion node.

        Args:
            G: Graph view that is rewritten in place.
            set_identity: If true, ``self.set_identity(G)`` is called once all
                matches have been processed.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            True if at least one node was fused, False otherwise.
        """
        matcher = GraphMatcher(
            match_function=lambda state, frag: (frag, state['match']))
        matcher.add_node(MatScaleNodeMatch())

        changed = False
        for sub_graph, match_info in matcher.match_graph(G):
            # Each (node, idx) pair names an input slot; grab the edge that
            # currently feeds it, before the graph is mutated.
            inputs = []
            for src_node, in_idx in match_info['inputs']:
                inputs.append(G.indexed_in_edges(src_node.name)[in_idx])
            old_node = [*sub_graph.nodes()][0]
            # Capture outgoing edges now; they are lost once the node is removed.
            outputs = G.out_edges(old_node.name)
            changed = True

            G.remove(old_node)
            # NOTE(review): input_mapping always maps exactly two inputs, while
            # `inputs` has one entry per match_info['inputs'] item; confirm the
            # matcher never reports a different input count.
            fused = MatScaleFusionParameters(
                "{}_fusion".format(old_node.name),
                fusion_type=match_info['type'],
                subgraph=sub_graph,
                input_mapping=[[(old_node, 0)], [(old_node, 1)]])
            G.add_node(fused)

            # Fusion input i is fed by whatever fed the i-th matched input.
            for to_idx, edge in enumerate(inputs):
                edge.to_node = fused
                edge.to_idx = to_idx
                G.add_edge(edge)
            # Outputs keep their destinations but now originate at the fusion.
            for edge in outputs:
                edge.from_node = fused
                G.add_edge(edge)

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 2
0
def test_match4alt():
    """Match the two edges converging on test4 using a grouped-inputs factory.

    Graph: test1 -> test2 -> test4 <- test3 (into input 1), test4 -> test5.
    Both pattern edges come from one ``MatchEdgeInputsGroupFactory``, and no
    explicit ``to_idx`` is given. Presumably the group factory resolves the
    input slots itself. The result should be a single fragment with
    3 nodes and 2 edges.
    """
    G = Graph()
    for name in ("test1", "test2", "test3", "test4", "test5"):
        G.add_node(Node(name))
    for src, dst in (('test1', 'test2'), ('test2', 'test4')):
        G.add_edge(Edge(src, dst))
    G.add_edge(Edge('test3', 'test4', to_idx=1))
    G.add_edge(Edge('test4', 'test5'))

    matcher = GraphMatcher()
    left = MatchNodeByName("test2")
    right = MatchNodeByName('test3')
    sink = MatchNodeByName('test4')
    group = MatchEdgeInputsGroupFactory()

    for source in (left, right):
        matcher.add_edge(group.get_edge(from_node=source, to_node=sink))

    found = matcher.match_graph(G)
    assert len(found) == 1
    frag = found[0]
    assert frag.num_nodes() == 3
    assert frag.num_edges() == 2
Esempio n. 3
0
 def get_matcher(self):
     """Build a matcher for a pair of connected matmul-pattern nodes.

     Both endpoints are ``MatScalePairMatchFactory.MatchMatMulTwice`` built for
     ``MatrixMulParameters``. They are joined by a ``MatScaleEdgeMatch``
     that receives ``self`` explicitly (presumably the owning factory — confirm).
     The returned matcher's match function yields
     ``(fragment, state['match'])`` for each match.
     """
     matcher = GraphMatcher(
         match_function=lambda state, frag: (frag, state['match']))
     upstream = MatScalePairMatchFactory.MatchMatMulTwice(MatrixMulParameters)
     downstream = MatScalePairMatchFactory.MatchMatMulTwice(MatrixMulParameters)
     link = MatScalePairMatchFactory.MatScaleEdgeMatch(
         self, from_node=upstream, to_node=downstream)
     matcher.add_edge(link)
     return matcher
Esempio n. 4
0
def test_mn3_match2(mn3_graph, caplog):
    """``MatScaleNodeMatch`` should find 16 candidates in the mn3 test model.

    The ``mn3_graph`` fixture (presumably a MobileNetV3 TFLite file — confirm)
    is imported with tensors loaded, then dimensioned and reordered before
    matching.
    """
    importer = TfliteImporter()
    G = importer.create_graph(mn3_graph, {'load_tensors': True})
    G.add_dimensions()
    G.adjust_order()
    caplog.set_level(logging.DEBUG)

    matcher = GraphMatcher(
        match_function=lambda state, frag: (frag, state['match']))
    matcher.add_node(MatScaleNodeMatch())
    matches = matcher.match_graph(G)
    assert len(matches) == 16
Esempio n. 5
0
def test_match1():
    """A single name-based edge pattern matches exactly one 2-node fragment.

    The chain test1 -> test2 -> test3 -> test4 is built partly by passing a
    ``Node`` directly to ``Edge`` (for test3). Presumably ``add_edge``
    registers that node, since test3 is later referenced by name. The
    ``verify_edges`` check must come back falsy before matching.
    """
    G = Graph()
    for name in ("test1", "test2"):
        G.add_node(Node(name))
    G.add_edge(Edge('test1', 'test2'))
    G.add_edge(Edge('test2', Node('test3')))
    G.add_node(Node('test4'))
    G.add_edge(Edge('test3', 'test4'))
    assert not G.verify_edges()

    matcher = GraphMatcher()
    head = MatchNodeByName("test1")
    tail = MatchNodeByName('test2')
    matcher.add_edge(MatchEdgeByIdx(from_node=head, to_node=tail))

    found = matcher.match_graph(G)
    assert len(found) == 1
    frag = found[0]
    assert len(frag) == 2
    assert frag.num_edges() == 1
Esempio n. 6
0
def test_match_fail(caplog):
    """The matcher must recover after a partial match fails.

    ``MatchSuffNode`` matches at most one node whose name ends with the given
    suffix, and refuses further matches until ``reset_match`` clears it. In
    the chain test1a -> test2a -> test3b -> test4a, an ``a`` -> ``b`` pattern
    can only succeed on test2a -> test3b. The matcher therefore presumably
    has to reset the ``a`` node after trying test1a, whose successor test2a
    does not end in ``b``.
    """
    caplog.set_level(logging.DEBUG)

    class MatchSuffNode(NodeMatch):
        """Match at most one node whose name ends with ``suff``."""

        def __init__(self, suff):
            self._suff = suff
            # Latched on the first successful match; cleared by reset_match.
            self._has_matched = False

        def match(self, G, node, state):
            if self._has_matched or not node.name.endswith(self._suff):
                return False
            self._has_matched = True
            return True

        def commit_match(self, G, node, state):
            # Nothing to record on commit for this test matcher.
            pass

        def reset_match(self, G, state, node=None, init=False):
            self._has_matched = False

    G = Graph()
    G.add_node(Node("test1a"))
    G.add_edge(Edge('test1a', Node('test2a')))
    G.add_edge(Edge('test2a', Node('test3b')))
    G.add_edge(Edge('test3b', Node('test4a')))

    fragment = GraphMatcher()
    fragment.add_edge(
        MatchEdgeByIdx(from_node=MatchSuffNode('a'),
                       to_node=MatchSuffNode('b')))

    res = fragment.match_graph(G)
    assert len(res) == 1
    assert res[0].num_nodes() == 2
    assert res[0].num_edges() == 1
    assert {node.name for node in res[0].nodes()} == {'test2a', 'test3b'}
Esempio n. 7
0
def test_match3(caplog):
    """A fan-out pattern rooted at test1 matches test1 -> {test2, test3}.

    Graph: test1 -> test2 -> test4 and test1 -> test3. The expected result
    is one fragment with 3 nodes and 2 edges; the test2 -> test4 edge is
    not part of the pattern.
    """
    caplog.set_level(logging.DEBUG)
    G = Graph()
    for name in ("test1", "test2"):
        G.add_node(Node(name))
    G.add_edge(Edge('test1', 'test2'))
    G.add_edge(Edge('test1', Node('test3')))
    G.add_node(Node('test4'))
    G.add_edge(Edge('test2', 'test4'))

    matcher = GraphMatcher()
    # One MatchNodeByName instance is the source of both pattern edges,
    # presumably forcing both edges to start at the same graph node.
    source = MatchNodeByName('test1')
    for target in ('test2', 'test3'):
        matcher.add_edge(
            MatchEdgeByIdx(from_node=source, to_node=MatchNodeByName(target)))

    found = matcher.match_graph(G)
    assert len(found) == 1
    frag = found[0]
    assert frag.num_nodes() == 3
    assert frag.num_edges() == 2
Esempio n. 8
0
def test_match4():
    """Match the two edges converging on test4 with explicit input indexes.

    Graph: test1 -> test2 -> test4 <- test3 (into input 1), test4 -> test5.
    The pattern pins test2 to test4's default input and test3 to input 1.
    The result should be a single fragment with 3 nodes and 2 edges.
    """
    G = Graph()
    for name in ("test1", "test2", "test3", "test4", "test5"):
        G.add_node(Node(name))
    for src, dst in (('test1', 'test2'), ('test2', 'test4')):
        G.add_edge(Edge(src, dst))
    G.add_edge(Edge('test3', 'test4', to_idx=1))
    G.add_edge(Edge('test4', 'test5'))

    matcher = GraphMatcher()
    left = MatchNodeByName("test2")
    right = MatchNodeByName('test3')
    sink = MatchNodeByName('test4')
    matcher.add_edge(MatchEdgeByIdx(from_node=left, to_node=sink))
    matcher.add_edge(MatchEdgeByIdx(from_node=right, to_node=sink, to_idx=1))

    found = matcher.match_graph(G)
    assert len(found) == 1
    frag = found[0]
    assert frag.num_nodes() == 3
    assert frag.num_edges() == 2