def test_alternating_longest():
    """Alternation must pick the longer branch regardless of operand order.

    Both ``b | b + c`` and ``b + c | b`` following ``a`` are expected to yield
    the single match ``a -> b -> c``.
    """
    # Layout (node ids presumably assigned 1..5 by add_nodes in list order):
    #   b c
    # a     d
    #    b
    graph = nx.DiGraph()
    add_nodes(graph, ['a', 'b', 'c', 'b', 'd'])
    graph.add_edges_from([(1, 2), (2, 3), (3, 5), (1, 4), (4, 5)])

    short_branch_first = N('a') + (N('b') | N('b') + N('c'))
    long_branch_first = N('a') + (N('b') + N('c') | N('b'))

    first_result = search_all(short_branch_first, graph)
    second_result = search_all(long_branch_first, graph)

    expected = [[1, 2, 3]]
    assert second_result == first_result == expected
# Example No. 2
# 0
 def get_matching_nncf_graph_pattern_io_list(
         self, expression: Expression) -> List[NNCFGraphPatternIO]:
     """Apply ``_get_nncf_graph_pattern_io_list`` to every match of *expression*.

     :param expression: pattern expression searched for in ``self._nx_graph``
         via ``search_all``.
     :return: one ``_get_nncf_graph_pattern_io_list`` result per match, in the
         order in which ``search_all`` yields the matches.
     """
     return [
         self._get_nncf_graph_pattern_io_list(node_key_sequence)
         for node_key_sequence in search_all(self._nx_graph, expression)
     ]
def test_graph_branching():
    """``a + b`` matches only the 'a' node that is directly followed by a 'b'."""
    graph = nx.DiGraph()
    add_nodes(graph, ['a', 'b', 'a', 'c'])
    graph.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4)])

    a_then_b = N('a') + N('b')
    assert search_all(a_then_b, graph) == [[1, 2]]
def test_two_matched():
    """A repeated ``b -> c`` segment yields one match per occurrence, in order."""
    graph = nx.DiGraph()
    add_nodes(graph, ['a', 'b', 'c'] * 2)
    # Simple chain 1 -> 2 -> 3 -> 4 -> 5 -> 6.
    graph.add_edges_from((idx, idx + 1) for idx in range(1, 6))

    found = search_all(N('b') + N('c'), graph)
    assert found == [[2, 3], [5, 6]]
def test_simple():
    """A two-node sequence expression matches exactly one segment of a chain."""
    graph = nx.DiGraph()
    add_nodes(graph, ['a', 'b', 'c', 'a'])
    graph.add_edges_from([(1, 2), (2, 3), (3, 4)])

    b_then_c = N('b') + N('c')
    assert search_all(b_then_c, graph) == [[2, 3]]
def test_alternating():
    """The second alternative of ``a + (a | b)`` matches an ``a -> b`` edge."""
    graph = nx.DiGraph()
    add_nodes(graph, ['a', 'b'])
    graph.add_edge(1, 2)

    a_then_a_or_b = N('a') + (N('a') | N('b'))
    assert search_all(a_then_a_or_b, graph) == [[1, 2]]
def test_branching_expression2():
    """A three-way ``&`` group matches all parallel branches between a and d."""
    # Layout: 'a' fans out to 'b', 'c', 'e', which all feed 'd'.
    #   b
    # a e  d
    #   c
    graph = nx.DiGraph()
    add_nodes(graph, ['a', 'b', 'c', 'd', 'e'])
    graph.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (1, 5), (5, 4)])

    expression = N('a') + (N('b') & N('c') & N('e')) + N('d')

    assert search_all(expression, graph) == [[1, 2, 3, 5, 4]]
# Example No. 8
# 0
    def get_merged_original_graph_with_patterns(self,
                                                original_graph: NNCFGraph):
        """Return a copy of ``original_graph``'s networkx graph with matched chains merged.

        The searched pattern is ``p.LINEAR_OPS`` followed by either
        ``p.ANY_BN_ACT_COMBO`` or ``p.ELTWISE_UNIFORM_OPS``. Every match longer
        than one node is replaced by a single node whose key is the matched
        keys concatenated, each followed by a newline. The merged node carries:

        * ``NNCFGraph.KEY_NODE_ATTR`` -- the merged key;
        * ``self.NODES_GRAPH_ATTR`` -- the matched nodes' attribute dicts,
          taken from ``original_graph``;
        * ``self.IS_MERGED_GRAPH_ATTR`` -- ``True`` (every node present in the
          copy beforehand gets ``False``).

        In-edges of the first matched node and out-edges of the last matched
        node are re-created on the merged node with deep-copied attributes.
        Only the copy is edited; ``original_graph`` is not modified here.

        NOTE(review): edges attached to intermediate nodes of a match are
        dropped together with those nodes -- confirm this is intended.
        """
        import nncf.dynamic_graph.patterns as p
        from nncf.dynamic_graph.graph_matching import search_all

        pattern = p.LINEAR_OPS + p.ANY_BN_ACT_COMBO | p.LINEAR_OPS + p.ELTWISE_UNIFORM_OPS
        # pylint: disable=protected-access
        source_nx_graph = original_graph._nx_graph
        matches = search_all(source_nx_graph, pattern)
        merged_graph = deepcopy(source_nx_graph)
        nx.set_node_attributes(merged_graph, False, self.IS_MERGED_GRAPH_ATTR)

        for match in matches:
            # A single-node match has nothing to merge.
            if len(match) == 1:
                continue

            head_key, tail_key = match[0], match[-1]
            # Snapshot boundary edges and their attributes before the matched
            # nodes (and with them these edges) are removed from the copy.
            incoming = {
                edge: deepcopy(merged_graph.edges[edge])
                for edge in merged_graph.in_edges(head_key)
            }
            outgoing = {
                edge: deepcopy(merged_graph.edges[edge])
                for edge in merged_graph.out_edges(tail_key)
            }

            merged_node_key = ''.join(key + '\n' for key in match)
            merged_nodes = [source_nx_graph.nodes[key] for key in match]
            # remove_node (not remove_nodes_from) so a missing node still raises.
            for key in match:
                merged_graph.remove_node(key)

            merged_graph.add_node(merged_node_key, **{
                NNCFGraph.KEY_NODE_ATTR: merged_node_key,
                self.NODES_GRAPH_ATTR: merged_nodes,
                self.IS_MERGED_GRAPH_ATTR: True,
            })
            for edge, attrs in incoming.items():
                merged_graph.add_edge(edge[0], merged_node_key, **attrs)
            for edge, attrs in outgoing.items():
                merged_graph.add_edge(merged_node_key, edge[1], **attrs)

        return merged_graph
# Example No. 9
# 0
def test_branching_expression():
    """``a + (b & c) + d`` matches the diamond a -> {b, c} -> d as one match."""
    g = nx.DiGraph()
    #   b
    # a   d
    #   c
    add_nodes(g, ['a', 'b', 'c', 'd'])
    g.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4)])

    ex = N('a') + (N('b') & N('c')) + N('d')

    # NOTE(review): the other tests here call search_all(expression, graph),
    # while this one (like the NNCFGraph methods) passes (graph, expression);
    # confirm which signature the targeted graph_matching version expects.
    matches = search_all(g, ex)
    assert matches == [[1, 2, 3, 4]]
# Example No. 10
# 0
 def get_insertion_point_nodes_after_pattern(self, expression: Expression) -> List[NNCFNode]:
     """Return ``NNCFNode`` objects for the insertion points derived from matches.

     All node-key sequences that ``search_all`` finds for *expression* in
     ``self._nx_graph`` are passed to ``self._find_insertion_points``. Each key
     it returns is looked up in ``self._nx_graph`` and converted with
     ``NNCFGraph._nx_node_to_nncf_node``.

     :param expression: pattern expression to search for.
     :return: converted nodes, in the order ``_find_insertion_points`` returns keys.
     """
     all_matches = search_all(self._nx_graph, expression)
     insertion_point_keys = self._find_insertion_points(all_matches)
     return [
         NNCFGraph._nx_node_to_nncf_node(self._nx_graph.nodes[ip_key])
         for ip_key in insertion_point_keys
     ]
# Example No. 11
# 0
    def get_ip_graph_with_merged_hw_optimized_operations(
            self,
            hw_config: Optional[HWConfig] = None) -> 'InsertionPointGraph':
        """Return a deep copy of this graph with mergeable operator chains collapsed.

        Matches of ``self._get_mergeable_operator_patterns(hw_config)`` found in
        ``self._base_nx_graph`` that contain more than one node are merged into
        a single node in the copy. A match is skipped when
        ``self._base_graph_match_has_breaking_output_edges(match)`` is true.

        For each merged chain:

        * the merged node key is the matched node keys concatenated, each
          followed by a newline;
        * its attributes are a deep copy of the first matched node's
          attributes. ``ASSOCIATED_IP_NODE_KEYS_NODE_ATTR`` is rebuilt to keep
          only the associated IP node keys that are endpoints of the chain's
          incoming edge or of its outgoing edges. Every other associated IP
          node is removed from the copy;
        * the (at most one) in-edge of the first node and all out-edges of the
          last node are re-attached to the merged node with copied attributes.

        All reads go to ``self`` and all edits go to the copy, so ``self`` is
        not modified.

        :param hw_config: optional config forwarded to
            ``_get_mergeable_operator_patterns``.
        :return: the merged copy.
        """
        #pylint:disable=too-many-branches
        merged_ip_graph = deepcopy(self)
        pattern = self._get_mergeable_operator_patterns(hw_config)
        from nncf.dynamic_graph.graph_matching import search_all
        matches = search_all(self._base_nx_graph, pattern)
        ip_keys_attr = InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR
        for match in matches:
            # A single-operator match has nothing to merge.
            if len(match) == 1:
                continue

            input_node_key = match[0]
            output_node_key = match[-1]

            # If a subgraph has output edges in its middle, should skip merging it
            # Example:
            #       (conv2d)
            #          |------\
            #         (BN)    |
            #          |      |
            #        (RELU)   |
            #          |      |
            #        (cat)----/
            #          |
            #         ...
            if self._base_graph_match_has_breaking_output_edges(match):
                continue

            in_edges = list(self.in_edges(input_node_key))
            out_edges = list(self.out_edges(output_node_key))

            assert len(in_edges) <= 1  # TODO: change to == 1 when input nodes are handled correctly

            # Boundary edge attributes are read from ``self``, which this
            # method never edits, so earlier merges do not affect them.
            in_edge_key = in_edges[0] if in_edges else None
            in_edge_copy = deepcopy(self.edges[in_edge_key]) if in_edges else None
            out_edge_copies_dict = {
                out_edge_key: deepcopy(self.edges[out_edge_key])
                for out_edge_key in out_edges
            }

            # Endpoints of every edge that survives the merge. Build a fresh
            # list so ``out_edges`` is not mutated through an alias. Use a set
            # so each IP-node check is O(1) instead of a scan over the edges.
            conserved_edges = out_edges + ([in_edge_key] if in_edges else [])
            conserved_node_keys = {
                endpoint for edge_key in conserved_edges for endpoint in edge_key
            }

            merged_node_attrs = deepcopy(self.nodes[input_node_key])
            kept_ip_node_keys = set()
            merged_node_attrs[ip_keys_attr] = kept_ip_node_keys
            for node_key in match:
                for ip_node_key in self.nodes[node_key][ip_keys_attr]:
                    if ip_node_key in conserved_node_keys:
                        kept_ip_node_keys.add(ip_node_key)
                    else:
                        merged_ip_graph.remove_node(ip_node_key)
                merged_ip_graph.remove_node(node_key)
            merged_node_key = ''.join(node_key + '\n' for node_key in match)

            merged_ip_graph.add_node(merged_node_key, **merged_node_attrs)
            if in_edges:
                merged_ip_graph.add_edge(in_edge_key[0], merged_node_key,
                                         **in_edge_copy)
            for out_edge_key, out_edge_attrs in out_edge_copies_dict.items():
                merged_ip_graph.add_edge(merged_node_key, out_edge_key[1],
                                         **out_edge_attrs)

        return merged_ip_graph