Ejemplo n.º 1
0
def test_ops_combination_patterns():
    """Check matching of patterns combined with ``+``, ``|`` and ``join_patterns``."""

    def _make_graph(typed_nodes, edges):
        # Add the nodes first, in the given order, each with a ``type``
        # attribute. Then add the edges, in the given order.
        graph = nx.DiGraph()
        for key, node_type in typed_nodes:
            graph.add_node(key, type=node_type)
        for src, dst in edges:
            graph.add_edge(src, dst)
        return graph

    # Match an 'a' -> 'c' chain against the sequential composition first + second.
    sequential = TestPattern.first_pattern + TestPattern.second_pattern
    graph = _make_graph([('1', 'a'), ('2', 'c')], [('1', '2')])
    assert find_subgraphs_matching_pattern(graph, sequential) == [['1', '2']]

    # Match the same chain against (first + second) | third. In Python, `+`
    # binds tighter than `|`, so the parentheses only make that explicit.
    alternative = (TestPattern.first_pattern + TestPattern.second_pattern) | TestPattern.third_pattern
    graph = _make_graph([('1', 'a'), ('2', 'c')], [('1', '2')])
    assert find_subgraphs_matching_pattern(graph, alternative) == [['1', '2']]

    # Build first + second, then join it to the third pattern with an edge
    # from every node of the composed pattern to every node of the third one.
    joined = TestPattern.first_pattern + TestPattern.second_pattern
    cross_edges = list(itertools.product(list(joined.graph.nodes),
                                         list(TestPattern.third_pattern.graph.nodes)))
    joined.join_patterns(TestPattern.third_pattern, cross_edges)
    graph = _make_graph([('1', 'a'), ('2', 'c'), ('3', 'e')],
                        [('1', '2'), ('1', '3'), ('2', '3')])
    assert find_subgraphs_matching_pattern(graph, joined) == [['1', '2', '3']]
Ejemplo n.º 2
0
def get_merged_original_graph_with_pattern(
        orig_graph: nx.DiGraph) -> nx.DiGraph:
    """
    Collapse every multi-node match of the full HW fused pattern into one node.

    The input graph is not modified; a deep copy is merged instead. Every node
    of the copy first receives default values for the dummy, merged and
    activation-shape attributes. Each match with more than one node is then
    replaced by a single node. Its key is built from the "<id> <type>  "
    strings of the matched nodes. It keeps the list of their types and their
    original attribute dicts. Only the edges entering the first matched node
    and leaving the last matched node are reconnected to the new node.

    :param orig_graph: Original graph of model
    :return: Graph with merged nodes by patterns
    """
    fused_pattern = PT_HW_FUSED_PATTERNS.get_full_pattern_graph()
    matches = find_subgraphs_matching_pattern(orig_graph, fused_pattern)

    merged_graph = deepcopy(orig_graph)
    for attr_name, default_value in (
            (SearchGraph.IS_DUMMY_NODE_ATTR, False),
            (SearchGraph.IS_MERGED_NODE_ATTR, False),
            (SearchGraph.ACTIVATION_INPUT_SHAPE_ATTR, None),
            (SearchGraph.ACTIVATION_OUTPUT_SHAPE_ATTR, None)):
        nx.set_node_attributes(merged_graph, default_value, attr_name)

    for match in matches:
        # A single-node match has nothing to fuse.
        if len(match) == 1:
            continue

        first_key, last_key = match[0], match[-1]
        # Snapshot the boundary edges and their attributes before the matched
        # nodes (and with them these edges) are removed.
        incoming = {edge: deepcopy(merged_graph.edges[edge])
                    for edge in list(merged_graph.in_edges(first_key))}
        outgoing = {edge: deepcopy(merged_graph.edges[edge])
                    for edge in list(merged_graph.out_edges(last_key))}

        key_parts = []
        merged_nodes = []
        type_list = []
        for node_key in match:
            # Attributes come from the untouched input graph, not the copy.
            attrs = orig_graph.nodes[node_key]
            node_type = attrs[SearchGraph.TYPE_NODE_ATTR]
            key_parts.append(str(attrs['id']) + ' ' + node_type + '  ')
            merged_nodes.append(attrs)
            merged_graph.remove_node(node_key)
            type_list.append(node_type)
        merged_node_key = ''.join(key_parts)

        merged_graph.add_node(merged_node_key, **{
            SearchGraph.KEY_NODE_ATTR: merged_node_key,
            SearchGraph.IS_MERGED_NODE_ATTR: True,
            SearchGraph.TYPE_NODE_ATTR: type_list,
            SearchGraph.MERGED_NODES_NODE_ATTR: merged_nodes,
        })
        for (src, _), edge_attrs in incoming.items():
            merged_graph.add_edge(src, merged_node_key, **edge_attrs)
        for (_, dst), edge_attrs in outgoing.items():
            merged_graph.add_edge(merged_node_key, dst, **edge_attrs)

    return merged_graph
Ejemplo n.º 3
0
def test_matches_with_non_pattern_node_type():
    """Check matching of patterns that contain ``forth_pattern`` parts and of
    ``TestPattern.pattern_with_non_pattern_nodes``."""

    def _chain(*typed_nodes):
        # Build a linear DiGraph k0 -> k1 -> ... from (key, type) pairs. All
        # nodes are added first, then the edges, both in the given order.
        graph = nx.DiGraph()
        for key, node_type in typed_nodes:
            graph.add_node(key, type=node_type)
        keys = [key for key, _ in typed_nodes]
        for src, dst in zip(keys, keys[1:]):
            graph.add_edge(src, dst)
        return graph

    # A leading forth_pattern: node '1' is not part of the expected match.
    pattern = TestPattern.forth_pattern + TestPattern.first_pattern + TestPattern.second_pattern
    graph = _chain(('1', 'a'), ('2', 'a'), ('3', 'c'))
    assert find_subgraphs_matching_pattern(graph, pattern) == [['2', '3']]

    # A forth_pattern on both ends yields the same expected match.
    pattern = (TestPattern.forth_pattern + TestPattern.first_pattern
               + TestPattern.second_pattern + TestPattern.forth_pattern)
    graph = _chain(('1', 'a'), ('2', 'a'), ('3', 'c'))
    assert find_subgraphs_matching_pattern(graph, pattern) == [['2', '3']]

    pattern = TestPattern.pattern_with_non_pattern_nodes

    graph = _chain(('1', 'a'), ('2', 'b'), ('3', 'c'))
    assert not find_subgraphs_matching_pattern(graph, pattern)

    graph = _chain(('1', 'a'), ('2', 'b'), ('4', 'a'))
    assert not find_subgraphs_matching_pattern(graph, pattern)

    graph = create_graph_with_many_nodes()
    assert find_subgraphs_matching_pattern(graph, pattern) == [['1', '2', '4', '3', '5', '6']]
Ejemplo n.º 4
0
def test_matches_with_any_pattern_node_type():
    """Check matching of ``TestPattern.pattern_with_any_pattern_nodes``.

    The two short chains are expected to yield no match. The graph from
    ``create_graph_with_many_nodes`` is expected to yield exactly one.
    """

    def _chain(*typed_nodes):
        # Build a linear DiGraph k0 -> k1 -> ... from (key, type) pairs. All
        # nodes are added first, then the edges, both in the given order.
        graph = nx.DiGraph()
        for key, node_type in typed_nodes:
            graph.add_node(key, type=node_type)
        keys = [key for key, _ in typed_nodes]
        for src, dst in zip(keys, keys[1:]):
            graph.add_edge(src, dst)
        return graph

    pattern = TestPattern.pattern_with_any_pattern_nodes

    graph = _chain(('1', 'a'), ('2', 'b'), ('3', 'c'))
    assert not find_subgraphs_matching_pattern(graph, pattern)

    graph = _chain(('1', 'a'), ('2', 'b'), ('4', 'a'))
    assert not find_subgraphs_matching_pattern(graph, pattern)

    graph = create_graph_with_many_nodes()
    assert find_subgraphs_matching_pattern(graph, pattern) == [['7', '1', '2', '4', '8', '3', '5', '9', '6']]
Ejemplo n.º 5
0
    def get_merged_original_graph_with_patterns(self,
                                                original_graph: PTNNCFGraph):
        """
        Return a copy of ``original_graph``'s networkx graph with HW-fused matches collapsed.

        The wrapped networkx graph is deep-copied, and every copied node gets
        ``self.IS_MERGED_GRAPH_ATTR`` set to False. Each match of the full HW
        fused pattern with more than one node is then replaced by one node.
        Its key is the matched node keys, each followed by a newline. It
        stores the attribute dicts of the matched nodes, taken from the
        original (not copied) graph, under ``self.NODES_GRAPH_ATTR``, and is
        flagged with ``self.IS_MERGED_GRAPH_ATTR = True``. Only edges entering
        the first matched node and leaving the last one are reconnected.
        """
        fused_pattern = PT_HW_FUSED_PATTERNS.get_full_pattern_graph()
        # pylint: disable=protected-access
        source_nx_graph = original_graph._nx_graph
        matches = find_subgraphs_matching_pattern(source_nx_graph, fused_pattern)

        merged_graph = deepcopy(source_nx_graph)
        nx.set_node_attributes(merged_graph, False, self.IS_MERGED_GRAPH_ATTR)
        for match in matches:
            # A single-node match has nothing to fuse.
            if len(match) == 1:
                continue

            head_key, tail_key = match[0], match[-1]
            # Snapshot the boundary edges and their attributes before the
            # matched nodes (and with them these edges) are removed.
            incoming = {edge: deepcopy(merged_graph.edges[edge])
                        for edge in list(merged_graph.in_edges(head_key))}
            outgoing = {edge: deepcopy(merged_graph.edges[edge])
                        for edge in list(merged_graph.out_edges(tail_key))}

            merged_nodes = [source_nx_graph.nodes[node_key] for node_key in match]
            merged_node_key = ''.join([node_key + '\n' for node_key in match])
            for node_key in match:
                merged_graph.remove_node(node_key)

            merged_graph.add_node(merged_node_key, **{
                PTNNCFGraph.KEY_NODE_ATTR: merged_node_key,
                self.NODES_GRAPH_ATTR: merged_nodes,
                self.IS_MERGED_GRAPH_ATTR: True,
            })
            for (src, _), edge_attrs in incoming.items():
                merged_graph.add_edge(src, merged_node_key, **edge_attrs)
            for (_, dst), edge_attrs in outgoing.items():
                merged_graph.add_edge(merged_node_key, dst, **edge_attrs)

        return merged_graph
Ejemplo n.º 6
0
def test_two_matches():
    """A graph with two separate candidate chains should yield both matches of first + second."""
    pattern = TestPattern.first_pattern + TestPattern.second_pattern

    node_types = {'1': 'a', '2': 'c', '3': 'e', '4': 'c', '5': 'a', '6': 'd'}
    graph = nx.DiGraph()
    graph.add_nodes_from((key, {'type': node_type})
                         for key, node_type in node_types.items())
    graph.add_edges_from([('1', '2'), ('2', '3'), ('5', '6')])

    # The order in which matches are reported is not part of the check.
    assert sorted(find_subgraphs_matching_pattern(graph, pattern)) == [['1', '2'], ['5', '6']]
Ejemplo n.º 7
0
    def get_ip_graph_with_merged_hw_optimized_operations(self,
                                                         full_fusing_pattern: GraphPattern) \
            -> 'InsertionPointGraph':
        """
        Returns a copy of this InsertionPointGraph in which the nodes matching a HW-specific list of patterns are
        fused into a single node.

        Matching is done against ``self._base_nx_graph``; ``self`` itself is not modified. For each match with more
        than one node, the new node's attributes start as a deep copy of the first matched node's attributes. An
        insertion point node associated with a matched node is kept only if it is an endpoint of an edge entering
        the first matched node or leaving the last one. Kept nodes are recorded in the merged node's associated
        IP keys; every other associated IP node is removed. As a result, the pre- and post-hooks located in the
        middle of the fused pattern are no longer accessible.

        :param full_fusing_pattern: The GraphPattern object representing a composition of fusing pattern variants.
        :return: The InsertionPointGraph with nodes fused according to pattern matching.
        """
        # pylint:disable=too-many-branches
        ip_keys_attr = InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR
        merged_ip_graph = deepcopy(self)
        matches = find_subgraphs_matching_pattern(self._base_nx_graph,
                                                  full_fusing_pattern)
        for match in matches:
            # A single-node match has nothing to fuse.
            if len(match) == 1:
                continue

            first_key, last_key = match[0], match[-1]

            in_edges = list(self.in_edges(first_key))
            out_edges = list(self.out_edges(last_key))
            in_edge_attrs_by_key = {edge: deepcopy(self.edges[edge]) for edge in in_edges}
            out_edge_attrs_by_key = {edge: deepcopy(self.edges[edge]) for edge in out_edges}

            # Endpoints of the boundary edges; associated IP nodes found here survive the merge.
            boundary_nodes = {node for edge in out_edges + in_edges for node in edge}

            merged_node_attrs = deepcopy(self.nodes[first_key])
            merged_node_attrs[ip_keys_attr] = set()
            merged_node_attrs[InsertionPointGraph.IS_MERGED_NODE_ATTR] = True

            key_parts = []
            merged_nncf_nodes = []
            for node_key in match:
                node_attrs = self.nodes[node_key]
                for ip_node_key in node_attrs[ip_keys_attr]:
                    if ip_node_key in boundary_nodes:
                        merged_node_attrs[ip_keys_attr].add(ip_node_key)
                    else:
                        merged_ip_graph.remove_node(ip_node_key)
                merged_nncf_nodes.append(node_attrs[InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR])
                merged_ip_graph.remove_node(node_key)
                key_parts.append(node_key + '\n')
            merged_node_key = ''.join(key_parts)

            merged_node_attrs[InsertionPointGraph.MERGED_NNCF_NODE_LIST_NODE_ATTR] = merged_nncf_nodes
            merged_ip_graph.add_node(merged_node_key, **merged_node_attrs)
            for edge, edge_attrs in in_edge_attrs_by_key.items():
                merged_ip_graph.add_edge(edge[0], merged_node_key, **edge_attrs)
            for edge, edge_attrs in out_edge_attrs_by_key.items():
                merged_ip_graph.add_edge(merged_node_key, edge[1], **edge_attrs)

        return merged_ip_graph