def test_build_quantizer_propagation_state_graph_from_ip_graph(self):
        """A freshly built QPSG must mirror its source insertion point graph.

        Every IP graph node and edge needs a QPSG counterpart; node attributes start
        in the "no quantizers placed" state and edge attributes are carried over,
        with an empty list of affecting propagating quantizers.
        """
        source_graph = InsertionPointGraph(get_two_branch_mock_model_graph())
        qpsg = QPSG(source_graph)
        assert len(source_graph.nodes) == len(qpsg.nodes)
        assert len(source_graph.edges) == len(qpsg.edges)

        for key, src_node in source_graph.nodes.items():
            built_node = qpsg.nodes[key]
            assert built_node[QPSG.NODE_TYPE_NODE_ATTR] == QPSG.ipg_node_type_to_qpsg_node_type(
                src_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR])
            built_type = built_node[QPSG.NODE_TYPE_NODE_ATTR]
            if built_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                # No quantizer sits on, or passes through, an insertion point yet.
                assert built_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                assert not built_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                assert built_node[QPSG.INSERTION_POINT_DATA_NODE_ATTR] == src_node[
                    InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            elif built_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
                # Operators default to non-quantizable with no allowed input types.
                assert not built_node[QPSG.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR]
                assert built_node[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] == QuantizationTrait.NON_QUANTIZABLE
                assert not built_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for src_key, dst_key, src_edge_attrs in source_graph.edges(data=True):
            built_edge_attrs = qpsg.edges[src_key, dst_key]
            assert not built_edge_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            for attr_name, attr_value in src_edge_attrs.items():
                assert built_edge_attrs[attr_name] == attr_value

        qpsg.run_consistency_check()
    class BranchHandlingState0(RedundantQuantizerMergeTestStruct):
        """Four independent quantizers on the I, C and D pre-hooks; all must remain."""
        ref_remaining_pq_positions = {
            InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=0),
            InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=1),
            InsertionPointGraph.get_pre_hook_node_key('C'),
            InsertionPointGraph.get_pre_hook_node_key('D')
        }
        operator_node_key_vs_trait_dict = {
            'I': QuantizationTrait.QUANTIZATION_AGNOSTIC,
            'C': QuantizationTrait.INPUTS_QUANTIZABLE,
            'G': QuantizationTrait.NON_QUANTIZABLE,
        }

        def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
            """Place one default-config quantizer on each target pre-hook, in order.

            This case will fail if, after going depth-first through the 'D' branch of the graph,
            the merge traversal function state is not reset (which is incorrect behavior)
            when starting to traverse the 'C' branch.
            """
            target_ip_keys = (
                InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=0),
                InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=1),
                InsertionPointGraph.get_pre_hook_node_key('C'),
                InsertionPointGraph.get_pre_hook_node_key('D'),
            )
            for ip_key in target_ip_keys:
                qpsg.add_propagating_quantizer([QuantizerConfig()], ip_key)
            return qpsg
    def test_add_propagating_quantizer(self, mock_qp_graph):
        """Adding a quantizer must only touch its IP node, its operator and the edge between them.

        Also checks that placing a quantizer on a post-hook insertion point raises RuntimeError.
        """
        ref_qconf_list = [QuantizerConfig(), QuantizerConfig(bits=6)]

        target_node_key = "F"
        target_ip_node_key = InsertionPointGraph.get_pre_hook_node_key(target_node_key)
        prop_quant = mock_qp_graph.add_propagating_quantizer(ref_qconf_list, target_ip_node_key)
        assert prop_quant.potential_quant_configs == ref_qconf_list
        assert prop_quant.current_location_node_key == target_ip_node_key
        assert prop_quant.affected_ip_nodes == {target_ip_node_key}
        assert prop_quant.last_accepting_location_node_key is None
        assert prop_quant.affected_edges == {(target_ip_node_key, target_node_key)}
        assert not prop_quant.propagation_path

        for node_key, node in mock_qp_graph.nodes.items():
            if node_key == target_node_key:
                assert node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            elif node_key == target_ip_node_key:
                assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
            else:
                # Every other node must be untouched by the new quantizer.
                assert not node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                if node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        # Loop-invariant check: done once here instead of on every node iteration.
        target_ip_node = mock_qp_graph.nodes[target_ip_node_key]
        assert target_ip_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

        for from_node, to_node, edge_data in mock_qp_graph.edges.data():
            if (from_node, to_node) == (target_ip_node_key, target_node_key):
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            else:
                assert not edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        with pytest.raises(RuntimeError):
            _ = mock_qp_graph.add_propagating_quantizer(ref_qconf_list,
                                                        InsertionPointGraph.get_post_hook_node_key(target_node_key))
Example #4
0
    def test_insertion_point_data_in_ip_nodes(self):
        """Both hook IP nodes around an operator must carry matching insertion point data.

        The predecessor hook must be an OPERATOR_PRE_HOOK and the successor hook an
        OPERATOR_POST_HOOK, and both must reference the operator's input-agnostic
        execution context.
        """
        # TODO: extend for modules
        single_op_graph = nx.DiGraph()
        ref_op_exec_context = OperationExecutionContext(
            "baz", Scope.from_str("Test/Scope[foo]/bar"), 0, [None])
        single_op_graph.add_node(0, **{NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: ref_op_exec_context})

        ip_graph = InsertionPointGraph(single_op_graph)
        ip_data_attr = InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR

        for op_key in single_op_graph.nodes:
            pre_hook_key = list(ip_graph.predecessors(op_key))[0]
            post_hook_key = list(ip_graph.successors(op_key))[0]
            pre_hook_ip = ip_graph.nodes[pre_hook_key][ip_data_attr]
            post_hook_ip = ip_graph.nodes[post_hook_key][ip_data_attr]

            assert pre_hook_ip.insertion_type == InsertionType.OPERATOR_PRE_HOOK
            assert post_hook_ip.insertion_type == InsertionType.OPERATOR_POST_HOOK

            assert pre_hook_ip.ia_op_exec_context == ref_op_exec_context.input_agnostic
            assert post_hook_ip.ia_op_exec_context == ref_op_exec_context.input_agnostic
    class NoRedundancyState1(RedundantQuantizerMergeTestStruct):
        """Per-channel C/D quantizers merged at B's post-hook, plus an E quantizer moved to D's pre-hook.

        Both resulting quantizer positions are expected to remain after redundancy merging.
        """

        ref_remaining_pq_positions = {
            InsertionPointGraph.get_post_hook_node_key('B'),
            InsertionPointGraph.get_pre_hook_node_key('D')
        }
        operator_node_key_vs_trait_dict = {
            'B': QuantizationTrait.QUANTIZATION_AGNOSTIC,
            'C': QuantizationTrait.INPUTS_QUANTIZABLE,
            'D': QuantizationTrait.QUANTIZATION_AGNOSTIC,
            'E': QuantizationTrait.INPUTS_QUANTIZABLE
        }

        def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
            # Merge the per-channel quantizers from the 'C' and 'D' pre-hooks at B's post-hook.
            pq_1 = qpsg.add_propagating_quantizer([QuantizerConfig(per_channel=True)],
                                                  InsertionPointGraph.get_pre_hook_node_key('C'))
            pq_2 = qpsg.add_propagating_quantizer([QuantizerConfig(per_channel=True)],
                                                  InsertionPointGraph.get_pre_hook_node_key('D'))
            qpsg.merge_quantizers_for_branching_node([pq_1, pq_2], [QuantizerConfig(per_channel=True)],
                                                     [None, None],
                                                     InsertionPointGraph.get_post_hook_node_key('B'))
            # Propagate a default-config quantizer from E's pre-hook to D's pre-hook.
            pq_3 = qpsg.add_propagating_quantizer([QuantizerConfig()],
                                                  InsertionPointGraph.get_pre_hook_node_key('E'))
            paths = get_edge_paths_for_propagation(qpsg,
                                                   InsertionPointGraph.get_pre_hook_node_key('D'),
                                                   InsertionPointGraph.get_pre_hook_node_key('E'))
            path = paths[0]
            qpsg.propagate_quantizer_via_path(pq_3, path)
            return qpsg
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Merge default-config C/D quantizers at B's post-hook, then add one at E's pre-hook."""
     branch_quantizers = [
         qpsg.add_propagating_quantizer([QuantizerConfig()],
                                        InsertionPointGraph.get_pre_hook_node_key(op_key))
         for op_key in ('C', 'D')
     ]
     qpsg.merge_quantizers_for_branching_node(branch_quantizers, [QuantizerConfig()],
                                              [None] * 2,
                                              InsertionPointGraph.get_post_hook_node_key('B'))
     e_pre_hook_key = InsertionPointGraph.get_pre_hook_node_key('E')
     qpsg.add_propagating_quantizer([QuantizerConfig()], e_pre_hook_key)
     return qpsg
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Place a default-config quantizer on both I input pre-hooks and on the C and D pre-hooks.

     This case will fail if, after going depth-first through the 'D' branch of the graph,
     the merge traversal function state is not reset (which is incorrect behavior)
     when starting to traverse the 'C' branch.
     """
     target_ip_keys = [
         InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=0),
         InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key('C'),
         InsertionPointGraph.get_pre_hook_node_key('D'),
     ]
     for ip_key in target_ip_keys:
         qpsg.add_propagating_quantizer([QuantizerConfig()], ip_key)
     return qpsg
Example #9
0
    def test_merge_quantizer_into_path(self,
                                       merge_quantizer_into_path_test_struct):
        """Propagate the starting quantizers, merge the flagged ones into their paths,
        and compare the resulting QPSG with the expected quantizer set."""
        quant_prop_graph = QPSG(InsertionPointGraph(self.get_model_graph()))

        for start_struct in merge_quantizer_into_path_test_struct.start_set_quantizers:
            traits_and_configs = start_struct.init_node_to_trait_and_configs_dict
            starting_quantizer_ip_node = start_struct.starting_quantizer_ip_node
            target_node = start_struct.target_node_for_quantizer
            is_merged = start_struct.is_merged
            prop_path = start_struct.prop_path

            # Reset every node to agnostic before applying this struct's traits.
            agnostic = QuantizationTrait.QUANTIZATION_AGNOSTIC
            for graph_node in quant_prop_graph.nodes.values():
                graph_node[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] = agnostic

            master_prop_quant = None
            # NOTE(review): re-created on each outer iteration, so only the last struct's
            # merge request reaches the merge loop below -- confirm this is intended.
            merged_prop_quant = []
            for op_node_key, trait_and_configs in traits_and_configs.items():
                trait, qconfigs = trait_and_configs[0], trait_and_configs[1]
                quant_prop_graph.nodes[op_node_key][QPSG.QUANTIZATION_TRAIT_NODE_ATTR] = trait
                if trait != QuantizationTrait.INPUTS_QUANTIZABLE:
                    continue
                ip_node_key = InsertionPointGraph.get_pre_hook_node_key(op_node_key)
                new_quant = quant_prop_graph.add_propagating_quantizer(qconfigs, ip_node_key)
                if ip_node_key == starting_quantizer_ip_node:
                    master_prop_quant = new_quant

            candidate_paths = get_edge_paths_for_propagation(quant_prop_graph, target_node,
                                                             starting_quantizer_ip_node)
            master_prop_quant = quant_prop_graph.propagate_quantizer_via_path(master_prop_quant,
                                                                              candidate_paths[0])
            if is_merged:
                merged_prop_quant.append((master_prop_quant, prop_path))

        for quantizer, merge_path in merged_prop_quant:
            quant_prop_graph.merge_quantizer_into_path(quantizer, merge_path)

        self.check_final_state_qpsg(quant_prop_graph,
                                    merge_quantizer_into_path_test_struct.expected_set_quantizers)
 def mock_qp_graph():
     """Yield a QPSG over the two-branch mock model graph.

     On teardown the graph is consistency-checked unless the test set ``skip_check``.
     """
     graph = QPSG(InsertionPointGraph(get_two_branch_mock_model_graph()))
     graph.skip_check = False
     yield graph
     if graph.skip_check:
         return
     graph.run_consistency_check()
    class NoConnectingPathsState1(RedundantQuantizerMergeTestStruct):
        """Two quantizers on the C and F pre-hooks; both are expected to remain."""
        ref_remaining_pq_positions = {
            InsertionPointGraph.get_pre_hook_node_key('C'),
            InsertionPointGraph.get_pre_hook_node_key('F')
        }
        operator_node_key_vs_trait_dict = {
            'C': QuantizationTrait.INPUTS_QUANTIZABLE,
            'F': QuantizationTrait.INPUTS_QUANTIZABLE
        }

        def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
            # One default-config quantizer per pre-hook; nothing is propagated.
            for op_key in ('C', 'F'):
                qpsg.add_propagating_quantizer([QuantizerConfig()],
                                               InsertionPointGraph.get_pre_hook_node_key(op_key))
            return qpsg
    def __init__(self, ip_graph: InsertionPointGraph):
        super().__init__()
        # Work on a deep copy: edge attribute dicts are modified below, and node data
        # stored in this graph should not alias objects owned by the caller's graph.
        ip_graph = deepcopy(ip_graph)
        self._created_prop_quantizer_counter = 0

        for node_key, ip_node in ip_graph.nodes.items():
            ip_node_type = ip_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            # NOTE(review): the IP graph node type is stored unconverted here -- confirm
            # callers expect InsertionPointGraphNodeType values in this attribute.
            qpg_node = {self.NODE_TYPE_NODE_ATTR: ip_node_type}
            if ip_node_type == InsertionPointGraphNodeType.INSERTION_POINT:
                # Insertion points start with no quantizer placed on or passing through them.
                qpg_node.update({
                    self.PROPAGATING_QUANTIZER_NODE_ATTR: None,
                    self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR: [],
                    self.INSERTION_POINT_DATA_NODE_ATTR:
                        ip_node[InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR],
                })
            elif ip_node_type == InsertionPointGraphNodeType.OPERATOR:
                # Operators start non-quantizable with no allowed input quantization types.
                qpg_node.update({
                    self.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR: set(),
                    self.QUANTIZATION_TRAIT_NODE_ATTR: QuantizationTrait.NON_QUANTIZABLE,
                    self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR: [],
                    self.OPERATOR_METATYPE_NODE_ATTR:
                        ip_node[InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR],
                })
            self.add_node(node_key, **qpg_node)

        for src_key, dst_key, edge_attrs in ip_graph.edges(data=True):
            edge_attrs[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
            self.add_edge(src_key, dst_key, **edge_attrs)
    class NoRedundancyState0(RedundantQuantizerMergeTestStruct):
        """A quantizer moved from F's pre-hook to C's post-hook, plus a bits=6 one left at F.

        Both positions are expected to remain after redundancy merging.
        """
        ref_remaining_pq_positions = {
            InsertionPointGraph.get_post_hook_node_key('C'),
            InsertionPointGraph.get_pre_hook_node_key('F'),
        }
        operator_node_key_vs_trait_dict = {
            'F': QuantizationTrait.INPUTS_QUANTIZABLE
        }

        def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
            f_pre_hook_key = InsertionPointGraph.get_pre_hook_node_key('F')
            upstream_quantizer = qpsg.add_propagating_quantizer([QuantizerConfig()], f_pre_hook_key)

            # A single-edge path from C's post-hook down to F's pre-hook.
            single_edge_path = [(InsertionPointGraph.get_post_hook_node_key('C'), f_pre_hook_key)]
            qpsg.propagate_quantizer_via_path(upstream_quantizer, single_edge_path)

            qpsg.add_propagating_quantizer([QuantizerConfig(bits=6)], f_pre_hook_key)
            return qpsg
Example #14
0
    def test_insertion_point_setup(self):
        """Each operator gets exactly one pre-hook and one post-hook IP node.

        Also checks three things. Paths from an operator to its original neighbours pass
        only through IP nodes. No IP graph node is isolated. Every edge endpoint is a
        node of the graph.
        """
        # TODO: Change testing premises when module pre/post-op hooks and input/output nodes
        # are correctly handled
        mock_graph = get_two_branch_mock_model_graph()
        ip_graph = InsertionPointGraph(mock_graph)

        # 2 additional nodes per each operator node
        assert len(ip_graph.nodes) == 3 * len(mock_graph.nodes)
        assert len(ip_graph.edges) == 3 * len(mock_graph.edges)

        type_attr = InsertionPointGraph.NODE_TYPE_NODE_ATTR
        ip_node_type = InsertionPointGraphNodeType.INSERTION_POINT

        for op_key in mock_graph.nodes:
            op_node = ip_graph.nodes[op_key]
            assert op_node[type_attr] == InsertionPointGraphNodeType.OPERATOR

            pred_keys = list(ip_graph.predecessors(op_key))
            succ_keys = list(ip_graph.successors(op_key))
            assert len(pred_keys) == 1
            assert len(succ_keys) == 1
            pre_hook_key, post_hook_key = pred_keys[0], succ_keys[0]
            assert ip_graph.nodes[pre_hook_key][type_attr] == ip_node_type
            assert ip_graph.nodes[post_hook_key][type_attr] == ip_node_type
            assert {pre_hook_key, post_hook_key} == op_node[
                InsertionPointGraph.ASSOCIATED_IP_NODE_KEYS_NODE_ATTR]

            # IP node insertion should not disrupt the graph superstructure
            for neighbour in mock_graph.neighbors(op_key):
                for path in list(nx.all_simple_paths(ip_graph, op_key, neighbour)):
                    for intermediate_key in path[1:-1]:
                        assert ip_graph.nodes[intermediate_key][type_attr] == ip_node_type

        for node_key in ip_graph.nodes:
            # No node may be left without any incident edge.
            assert list(ip_graph.predecessors(node_key)) or list(ip_graph.successors(node_key))

        for from_node_key, to_node_key in ip_graph.edges:
            assert from_node_key in ip_graph.nodes
            assert to_node_key in ip_graph.nodes
    class NoConnectingPathsState3(RedundantQuantizerMergeTestStruct):
        """C/D quantizers merged at B's post-hook plus one on E's pre-hook; both must remain."""
        ref_remaining_pq_positions = {
            InsertionPointGraph.get_post_hook_node_key('B'),
            InsertionPointGraph.get_pre_hook_node_key('E')
        }
        operator_node_key_vs_trait_dict = {
            'B': QuantizationTrait.QUANTIZATION_AGNOSTIC,
            'C': QuantizationTrait.INPUTS_QUANTIZABLE,
            'D': QuantizationTrait.INPUTS_QUANTIZABLE,
        }

        def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
            branch_quantizers = [
                qpsg.add_propagating_quantizer([QuantizerConfig()],
                                               InsertionPointGraph.get_pre_hook_node_key(op_key))
                for op_key in ('C', 'D')
            ]
            qpsg.merge_quantizers_for_branching_node(branch_quantizers, [QuantizerConfig()],
                                                     [None] * 2,
                                                     InsertionPointGraph.get_post_hook_node_key('B'))
            e_pre_hook_key = InsertionPointGraph.get_pre_hook_node_key('E')
            qpsg.add_propagating_quantizer([QuantizerConfig()], e_pre_hook_key)
            return qpsg
Example #16
0
    def test_get_ip_graph_with_merged_operations(self, mock_graph_factory,
                                                 dot_file_name):
        """Compare the merged-operations IP graph against a reference .dot file.

        If the reference file does not exist yet, it is generated from the current
        graph and must then be validated manually.
        """
        mock_graph = mock_graph_factory()
        ip_graph = InsertionPointGraph(mock_graph)
        merged_ip_graph = ip_graph.get_ip_graph_with_merged_hw_optimized_operations()

        data_dir = TEST_ROOT / 'data/reference_graphs/pattern_merging'  # type: Path
        path_to_dot_file = data_dir / '{}.dot'.format(dot_file_name)

        # validate .dot file manually!
        if not path_to_dot_file.exists():
            data_dir.mkdir(parents=True, exist_ok=True)
            nx.drawing.nx_pydot.write_dot(merged_ip_graph, str(path_to_dot_file))

        load_graph = nx.drawing.nx_pydot.read_dot(str(path_to_dot_file))

        # Somehow pydot mangles the \n characters while writing a .dot file,
        # so restore real newlines in the loaded keys before comparing.
        def _unmangle(key):
            return key.replace('\\n', '\n')

        sanitized_loaded_keys = [_unmangle(key) for key in load_graph.nodes]
        sanitized_loaded_edges = [(_unmangle(u), _unmangle(v))
                                  for u, v in nx.DiGraph(load_graph).edges]

        assert Counter(sanitized_loaded_keys) == Counter(merged_ip_graph.nodes)
        assert Counter(sanitized_loaded_edges) == Counter(merged_ip_graph.edges)
    def test_merge_quantizer_into_path(self, model_graph_qpsg, merge_quantizer_into_path_test_struct):
        """Mark traits, place and propagate the starting quantizers, merge the flagged ones
        into their paths, and compare the final QPSG with the expected quantizers.

        The graph is consistency-checked after each setup round and after each merge.
        """
        quant_prop_graph = model_graph_qpsg

        for start_struct in merge_quantizer_into_path_test_struct.start_set_quantizers:
            traits_and_configs = start_struct.init_node_to_trait_and_configs_dict
            starting_quantizer_ip_node = start_struct.starting_quantizer_ip_node
            target_node = start_struct.target_node_for_quantizer
            is_merged = start_struct.is_merged
            prop_path = start_struct.prop_path

            # Default every node to agnostic, then override with the struct's traits.
            default_trait = QuantizationTrait.QUANTIZATION_AGNOSTIC
            node_key_vs_trait_dict = dict.fromkeys(quant_prop_graph.nodes, default_trait)  # type: Dict[str, QuantizationTrait]
            primary_prop_quant = None
            # NOTE(review): re-created on each outer iteration, so only the last struct's
            # merge request reaches the merge loop below -- confirm this is intended.
            merged_prop_quant = []
            for node_key, trait_and_configs in traits_and_configs.items():
                node_key_vs_trait_dict[node_key] = trait_and_configs[0]
            quant_prop_graph = self.mark_nodes_with_traits(quant_prop_graph, node_key_vs_trait_dict)

            for node_key, trait_and_configs in traits_and_configs.items():
                trait, qconfigs = trait_and_configs[0], trait_and_configs[1]
                if trait != QuantizationTrait.INPUTS_QUANTIZABLE:
                    continue
                ip_node_key = InsertionPointGraph.get_pre_hook_node_key(node_key)
                new_quant = quant_prop_graph.add_propagating_quantizer(qconfigs, ip_node_key)
                if ip_node_key == starting_quantizer_ip_node:
                    primary_prop_quant = new_quant

            candidate_paths = get_edge_paths_for_propagation(quant_prop_graph,
                                                             target_node,
                                                             starting_quantizer_ip_node)
            primary_prop_quant = quant_prop_graph.propagate_quantizer_via_path(primary_prop_quant,
                                                                               candidate_paths[0])
            if is_merged:
                merged_prop_quant.append((primary_prop_quant, prop_path))

            quant_prop_graph.run_consistency_check()

        for quantizer, merge_path in merged_prop_quant:
            quant_prop_graph.merge_quantizer_into_path(quantizer, merge_path)
            quant_prop_graph.run_consistency_check()

        self.check_final_state_qpsg(quant_prop_graph,
                                    merge_quantizer_into_path_test_struct.expected_set_quantizers)
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Merge per-channel C/D quantizers at B's post-hook, then move an E quantizer to D's pre-hook."""
     per_channel_quantizers = [
         qpsg.add_propagating_quantizer([QuantizerConfig(per_channel=True)],
                                        InsertionPointGraph.get_pre_hook_node_key(op_key))
         for op_key in ('C', 'D')
     ]
     qpsg.merge_quantizers_for_branching_node(per_channel_quantizers,
                                              [QuantizerConfig(per_channel=True)],
                                              [None] * 2,
                                              InsertionPointGraph.get_post_hook_node_key('B'))
     e_quantizer = qpsg.add_propagating_quantizer([QuantizerConfig()],
                                                  InsertionPointGraph.get_pre_hook_node_key('E'))
     first_path = get_edge_paths_for_propagation(qpsg,
                                                 InsertionPointGraph.get_pre_hook_node_key('D'),
                                                 InsertionPointGraph.get_pre_hook_node_key('E'))[0]
     qpsg.propagate_quantizer_via_path(e_quantizer, first_path)
     return qpsg
 def model_graph_qpsg(self):
     """Return a fresh QPSG built over this test class's model graph."""
     return QPSG(InsertionPointGraph(self.get_model_graph()))
Example #20
0
    def __init__(self, ip_graph: InsertionPointGraph, ignored_scopes=None):
        """Build the propagation state graph from an insertion point graph.

        :param ip_graph: source insertion point graph; it is deep-copied, so the
            caller's graph and its attribute dicts are never modified.
        :param ignored_scopes: scopes matched against each operator via ``in_scope_list``;
            matching operators get an auxiliary barrier node spliced in after them.
        """
        super().__init__()
        ip_graph = deepcopy(ip_graph)
        self._created_prop_quantizer_counter = 0

        self._ignored_scopes = deepcopy(ignored_scopes)
        self.ignored_node_keys = []

        # (barrier node key, ignored operator node key) pairs; barriers are wired into
        # the graph only after all regular edges have been copied.
        barrier_node_extra_edges = []
        for node_key, node in ip_graph.nodes.items():
            ipg_node_type = node[InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            qpg_node = {
                self.NODE_TYPE_NODE_ATTR: self.ipg_node_type_to_qpsg_node_type(ipg_node_type)}
            if ipg_node_type == InsertionPointGraphNodeType.INSERTION_POINT:
                qpg_node[self.PROPAGATING_QUANTIZER_NODE_ATTR] = None
                qpg_node[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
                qpg_node[self.INSERTION_POINT_DATA_NODE_ATTR] = node[
                    InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            elif ipg_node_type == InsertionPointGraphNodeType.OPERATOR:
                qpg_node[self.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR] = set()
                qpg_node[self.QUANTIZATION_TRAIT_NODE_ATTR] = QuantizationTrait.NON_QUANTIZABLE
                qpg_node[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
                qpg_node[self.OPERATOR_METATYPE_NODE_ATTR] = node[
                    InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR]
                scope_node = str(
                    node[InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR][
                        NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].input_agnostic)

                if in_scope_list(scope_node, self._ignored_scopes):
                    self.ignored_node_keys.append(node_key)
                    qpg_node_barrier = {
                        self.NODE_TYPE_NODE_ATTR:
                        QuantizerPropagationStateGraphNodeType.AUXILIARY_BARRIER,
                        'label':
                        QuantizerPropagationStateGraph.BARRIER_NODE_KEY_POSTFIX
                    }
                    barrier_node_key = self.get_barrier_node_key(node_key)
                    self.add_node(barrier_node_key, **qpg_node_barrier)
                    barrier_node_extra_edges.append((barrier_node_key, node_key))

            self.add_node(node_key, **qpg_node)

        for from_node, to_node, edge_data in ip_graph.edges(data=True):
            edge_data[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
            self.add_edge(from_node, to_node, **edge_data)

        # Splice each barrier in after its ignored operator. The operator's first successor
        # is assumed to be its post-hook: op -> post_hook becomes op -> barrier -> post_hook.
        for barrier_node_key, op_node_key in barrier_node_extra_edges:
            post_hook_node_key = list(self.succ[op_node_key].keys())[0]
            # Give each new edge its own list. Sharing a single attribute dict between the
            # two add_edge calls would make both edges alias the same list object.
            self.add_edge(op_node_key, barrier_node_key,
                          **{QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR: []})
            self.add_edge(barrier_node_key, post_hook_node_key,
                          **{QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR: []})
            self.remove_edge(op_node_key, post_hook_node_key)
class TestQuantizerPropagationStateGraph:
    """Unit tests for the quantizer propagation state graph (``QPSG``).

    Most tests work on a QPSG built from ``get_two_branch_mock_model_graph()``
    (via the ``mock_qp_graph`` fixture); the merge tests use the smaller graph
    built by ``get_model_graph()``.
    """
    #pylint:disable=too-many-public-methods
    @staticmethod
    @pytest.fixture()
    def mock_qp_graph():
        """Yield a QPSG built from the two-branch mock model graph.

        After the test body finishes, ``run_consistency_check()`` is called on the
        graph unless the test set ``skip_check`` to True on it.
        """
        ip_graph = InsertionPointGraph(get_two_branch_mock_model_graph())
        qpsg = QPSG(ip_graph)
        qpsg.skip_check = False
        yield qpsg
        if not qpsg.skip_check:
            qpsg.run_consistency_check()

    def test_build_quantizer_propagation_state_graph_from_ip_graph(self):
        """Check that a fresh QPSG mirrors its insertion point graph and holds no quantizers.

        Node and edge counts must match; node types are translated from the IP
        graph; IP nodes carry no PQ and their IP data; operator nodes start as
        NON_QUANTIZABLE with no allowed input quantization types; all original edge
        attributes are copied and no edge is affected by any PQ.
        """
        ip_graph = InsertionPointGraph(get_two_branch_mock_model_graph())
        quant_prop_graph = QPSG(ip_graph)
        assert len(ip_graph.nodes) == len(quant_prop_graph.nodes)
        assert len(ip_graph.edges) == len(quant_prop_graph.edges)

        for ip_graph_node_key, ip_graph_node in ip_graph.nodes.items():
            qpg_node = quant_prop_graph.nodes[ip_graph_node_key]
            assert qpg_node[QPSG.NODE_TYPE_NODE_ATTR] == QPSG.ipg_node_type_to_qpsg_node_type(ip_graph_node[
                InsertionPointGraph.NODE_TYPE_NODE_ATTR])
            qpg_node_type = qpg_node[QPSG.NODE_TYPE_NODE_ATTR]
            if qpg_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                assert qpg_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                assert not qpg_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                assert qpg_node[QPSG.INSERTION_POINT_DATA_NODE_ATTR] == ip_graph_node[
                    InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            elif qpg_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
                assert not qpg_node[QPSG.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR]
                assert qpg_node[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] == QuantizationTrait.NON_QUANTIZABLE
                assert not qpg_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node, to_node, edge_data in ip_graph.edges(data=True):
            qpg_edge_data = quant_prop_graph.edges[from_node, to_node]
            assert not qpg_edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            for key, value in edge_data.items():
                assert qpg_edge_data[key] == value

        quant_prop_graph.run_consistency_check()

    def test_add_propagating_quantizer(self, mock_qp_graph):
        """Check the bookkeeping performed when a propagating quantizer (PQ) is added.

        A PQ added at the pre-hook of operator "F" must be stored in that insertion
        point node and be the only PQ affecting operator node "F" and the
        (pre-hook F -> F) edge; no other node or edge may be affected. Adding a PQ
        at the post-hook of "F" must raise RuntimeError.
        """
        ref_qconf_list = [QuantizerConfig(), QuantizerConfig(bits=6)]

        target_node_key = "F"
        target_ip_node_key = InsertionPointGraph.get_pre_hook_node_key(target_node_key)
        prop_quant = mock_qp_graph.add_propagating_quantizer(ref_qconf_list, target_ip_node_key)
        assert prop_quant.potential_quant_configs == ref_qconf_list
        assert prop_quant.current_location_node_key == target_ip_node_key
        assert prop_quant.affected_ip_nodes == {target_ip_node_key}
        assert prop_quant.last_accepting_location_node_key is None
        assert prop_quant.affected_edges == {(target_ip_node_key, target_node_key)}
        assert not prop_quant.propagation_path

        for node_key, node in mock_qp_graph.nodes.items():
            if node_key == target_node_key:
                assert node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            elif node_key == target_ip_node_key:
                assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
            else:
                assert not node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                if node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        # Independent of the node loop above, so checked exactly once.
        target_ip_node = mock_qp_graph.nodes[target_ip_node_key]
        assert target_ip_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

        for from_node, to_node, edge_data in mock_qp_graph.edges.data():
            if (from_node, to_node) == (target_ip_node_key, target_node_key):
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            else:
                assert not edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        with pytest.raises(RuntimeError):
            _ = mock_qp_graph.add_propagating_quantizer(ref_qconf_list,
                                                        InsertionPointGraph.get_post_hook_node_key(target_node_key))

    # Each entry: (start IP node key, list of expected edge paths to the immediately dominating insertion points).
    START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES = [
        # Non-branching case - starting from "E" pre-hook
        (InsertionPointGraph.get_pre_hook_node_key("E"),
         [[(InsertionPointGraph.get_post_hook_node_key("C"), InsertionPointGraph.get_pre_hook_node_key("E"))]]),

        # Non-branching case - starting from "C" post-hook
        (InsertionPointGraph.get_post_hook_node_key("C"),
         [[("C", InsertionPointGraph.get_post_hook_node_key("C")),
           (InsertionPointGraph.get_pre_hook_node_key("C"), "C")]]),

        # Branching case - starting from "F" pre-hook port 0
        (InsertionPointGraph.get_pre_hook_node_key("F"),
         [[(InsertionPointGraph.get_post_hook_node_key("D"), InsertionPointGraph.get_pre_hook_node_key("F"))]]),

        # Branching case - starting from "F" pre-hook port 1
        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         [[(InsertionPointGraph.get_post_hook_node_key("E"),
            InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1))]]),

    ]

    @staticmethod
    @pytest.fixture(params=START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES)
    def start_ip_node_and_path_to_dominating_node(request):
        """Parametrized (start IP node, reference paths) pair."""
        return request.param

    def test_get_paths_to_immediately_dominating_insertion_points(self, start_ip_node_and_path_to_dominating_node,
                                                                  mock_qp_graph):
        """Compare the paths returned for a start node with the reference paths, ignoring path order."""

        def get_cat_path_list(path_list: List[List[Tuple[str, str]]]):
            # Serialize each path into one string so that path lists can be compared as multisets.
            str_paths = [[str(edge[0]) + ' -> ' + str(edge[1]) for edge in path] for path in path_list]
            return [';'.join(path) for path in str_paths]

        start_node, ref_paths = start_ip_node_and_path_to_dominating_node
        test_paths = mock_qp_graph.get_paths_to_immediately_dominating_insertion_points(start_node)

        assert Counter(get_cat_path_list(ref_paths)) == Counter(get_cat_path_list(test_paths))


    # Each entry: (start IP node key, target IP node key) for a propagation.
    START_TARGET_NODES = [
        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("F")),

        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key("E")),

        (InsertionPointGraph.get_pre_hook_node_key("F"),
         InsertionPointGraph.get_post_hook_node_key("B")),
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_NODES)
    def start_target_nodes(request):
        """Parametrized (start IP node, target IP node) pair."""
        return request.param

    @pytest.mark.dependency(name="propagate_via_path")
    def test_quantizers_can_propagate_via_path(self, start_target_nodes, mock_qp_graph):
        """Check graph and PQ state after propagating a PQ along an explicit edge path.

        Every edge path between the target and start IP nodes is tried on its own
        deep copy of the graph, so the paths do not interfere with each other.
        """
        start_ip_node_key, target_ip_node_key = start_target_nodes

        # From "target" to "start" since propagation direction is inverse to edge direction
        rev_paths = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)

        for path in rev_paths:
            working_graph = deepcopy(mock_qp_graph)
            ref_prop_quant = working_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key)
            ref_affected_edges = deepcopy(ref_prop_quant.affected_edges)
            ref_affected_edges.update(set(path))
            ref_affected_ip_nodes = deepcopy(ref_prop_quant.affected_ip_nodes)
            prop_quant = working_graph.propagate_quantizer_via_path(ref_prop_quant, path)
            for from_node_key, to_node_key in path:
                edge_data = working_graph.edges[from_node_key, to_node_key]
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [ref_prop_quant]
                # The PQ has moved past every "to" node of the path, so none of them may still hold it.
                to_node = working_graph.nodes[to_node_key]
                if to_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert to_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                from_node = working_graph.nodes[from_node_key]
                if from_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    ref_affected_ip_nodes.add(from_node_key)
            working_graph.run_consistency_check()

            # The source node of the last edge of the path is where the PQ must end up.
            final_node_key, _ = path[-1]
            final_node = working_graph.nodes[final_node_key]
            assert final_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == ref_prop_quant

            assert prop_quant.current_location_node_key == final_node_key
            assert prop_quant.propagation_path == path
            assert prop_quant.affected_edges == ref_affected_edges
            assert prop_quant.affected_ip_nodes == ref_affected_ip_nodes

    # Each entry: (start IP node key, target IP node key, last accepting location forced on the PQ).
    START_TARGET_ACCEPTING_NODES = [
        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("F"),
         InsertionPointGraph.get_post_hook_node_key("F")),

        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key("C"),
         InsertionPointGraph.get_post_hook_node_key("C")),

        (InsertionPointGraph.get_pre_hook_node_key("D"),
         InsertionPointGraph.get_pre_hook_node_key("B"),
         InsertionPointGraph.get_post_hook_node_key("B")),
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_ACCEPTING_NODES)
    def start_target_accepting_nodes(request):
        """Parametrized (start IP node, target IP node, forced last accepting location) triple."""
        return request.param

    @pytest.mark.dependency(depends="propagate_via_path")
    def test_backtrack_propagation_until_accepting_location(self, start_target_accepting_nodes, mock_qp_graph):
        """Check that backtracking moves a PQ back to its last accepting location.

        The PQ is propagated from the start to the target IP node, its
        ``last_accepting_location_node_key`` is forced to the third tuple element,
        and after backtracking it must sit there, affecting only the edges of the
        path from that location to the start node.
        """
        start_ip_node_key, target_ip_node_key, forced_last_accepting_location = start_target_accepting_nodes
        # The expected path and node lookups below require a concrete accepting location.
        assert forced_last_accepting_location is not None

        prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                             start_ip_node_key)
        ref_affected_edges = deepcopy(prop_quant.affected_edges)

        # Here, the tested graph should have such a structure that there is only one path from target to start
        path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]
        prop_quant = mock_qp_graph.propagate_quantizer_via_path(prop_quant, path)
        prop_quant.last_accepting_location_node_key = forced_last_accepting_location
        resulting_path = get_edge_paths_for_propagation(mock_qp_graph,
                                                        forced_last_accepting_location, start_ip_node_key)[0]
        ref_affected_edges.update(set(resulting_path))

        prop_quant = mock_qp_graph.backtrack_propagation_until_accepting_location(prop_quant)

        assert prop_quant.current_location_node_key == forced_last_accepting_location
        assert prop_quant.affected_edges == ref_affected_edges
        assert prop_quant.propagation_path == resulting_path

        target_node = mock_qp_graph.nodes[target_ip_node_key]
        accepting_node = mock_qp_graph.nodes[forced_last_accepting_location]
        if forced_last_accepting_location != target_ip_node_key:
            assert target_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
            assert target_ip_node_key not in prop_quant.affected_ip_nodes
        assert accepting_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

    @pytest.mark.dependency(depends="propagate_via_path")
    def test_clone_propagating_quantizer(self, mock_qp_graph, start_target_nodes):
        """Check that a cloned PQ copies the original's state and is registered wherever the original is."""
        start_ip_node_key, target_ip_node_key = start_target_nodes

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]

        ref_prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                 start_ip_node_key)

        prop_quant = mock_qp_graph.propagate_quantizer_via_path(ref_prop_quant, rev_path)

        cloned_prop_quant = mock_qp_graph.clone_propagating_quantizer(prop_quant)

        assert cloned_prop_quant.affected_ip_nodes == prop_quant.affected_ip_nodes
        assert cloned_prop_quant.affected_edges == prop_quant.affected_edges
        assert cloned_prop_quant.propagation_path == prop_quant.propagation_path
        assert cloned_prop_quant.current_location_node_key == prop_quant.current_location_node_key

        for ip_node_key in prop_quant.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert cloned_prop_quant in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        for from_node_key, to_node_key in prop_quant.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert cloned_prop_quant in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The cloned quantizer has not been placed into any IP (multiple PQs in one IP are not supported right now),
        # so the graph is intentionally left in a state that would fail the fixture's consistency check.
        mock_qp_graph.skip_check = True

    # Each entry: (start, target) IP node keys of the PQ to remove, then (start, target) of the PQ to keep.
    START_TARGET_NODES_FOR_TWO_QUANTIZERS = [
        (InsertionPointGraph.get_pre_hook_node_key("E"),
         InsertionPointGraph.get_post_hook_node_key("C"),
         InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("C"),
         InsertionPointGraph.get_post_hook_node_key("A"),
         InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("D")),

        # Simulated quantizer merging result
        (InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_pre_hook_node_key("E"),
         InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("D"))
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_NODES_FOR_TWO_QUANTIZERS)
    def start_target_nodes_for_two_quantizers(request):
        """Parametrized (remove start, remove target, keep start, keep target) IP node keys."""
        return request.param

    @pytest.mark.dependency(depends="propagate_via_path")
    def test_remove_propagating_quantizer(self, mock_qp_graph, start_target_nodes_for_two_quantizers):
        """Check that removing one of two propagated PQs leaves no trace of it and does not disturb the other.

        After removal, the removed PQ must be absent from its last location and
        from every node and edge it affected (including edges shared with the
        kept PQ); the kept PQ's state and registrations must be unchanged.
        """
        start_ip_node_key_remove, target_ip_node_key_remove, start_ip_node_key_keep, target_ip_node_key_keep = \
            start_target_nodes_for_two_quantizers

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path_remove = get_edge_paths_for_propagation(mock_qp_graph,
                                                         target_ip_node_key_remove,
                                                         start_ip_node_key_remove)[0]
        rev_path_keep = get_edge_paths_for_propagation(mock_qp_graph,
                                                       target_ip_node_key_keep,
                                                       start_ip_node_key_keep)[0]

        prop_quant_to_remove = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                       start_ip_node_key_remove)
        prop_quant_to_remove = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_remove, rev_path_remove)

        prop_quant_to_keep = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key_keep)

        prop_quant_to_keep = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_keep, rev_path_keep)

        # Snapshot everything before removal, since removal may mutate the removed PQ's own bookkeeping.
        affected_ip_nodes = deepcopy(prop_quant_to_remove.affected_ip_nodes)
        affected_op_nodes = deepcopy(prop_quant_to_remove.affected_operator_nodes)
        # The removed PQ's own edges are the ones that must be cleaned up; the kept PQ's edges are checked too
        # so that shared edges are covered.
        affected_edges = set(prop_quant_to_remove.affected_edges) | set(prop_quant_to_keep.affected_edges)
        last_location = prop_quant_to_remove.current_location_node_key
        ref_quant_to_keep_state_dict = deepcopy(prop_quant_to_keep.__dict__)

        mock_qp_graph.remove_propagating_quantizer(prop_quant_to_remove)

        last_node = mock_qp_graph.nodes[last_location]
        assert last_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        for ip_node_key in affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for op_node_key in affected_op_nodes:
            node = mock_qp_graph.nodes[op_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_remove not in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        assert prop_quant_to_keep.__dict__ == ref_quant_to_keep_state_dict

        for ip_node_key in prop_quant_to_keep.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_keep in node[
                QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in prop_quant_to_keep.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_keep in edge[
                QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    # Each entry: (op nodes to mark with a non-agnostic trait,
    #              {start node key: expected set of immediately dominated non-agnostic op nodes}).
    QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES = [
        (["D", "E", "F"],
         {
             "B": {"D", "E"},
             InsertionPointGraph.get_pre_hook_node_key("B"): {"D", "E"},
             "E": {"F"},
             InsertionPointGraph.get_post_hook_node_key("D"): {"F"},
             "A": {"D", "E"},
             InsertionPointGraph.get_pre_hook_node_key("G"): set()
         }),
        (["C", "E", "H"],
         {
             InsertionPointGraph.get_pre_hook_node_key("C"): {"C"},
             InsertionPointGraph.get_post_hook_node_key("C"): {"E"},
             "D": {"H"},
             InsertionPointGraph.get_pre_hook_node_key("B"): {"C", "H"}, # corner case - has a branch without quantizers
             InsertionPointGraph.get_post_hook_node_key("H"): set()
         })
    ]

    @staticmethod
    @pytest.fixture(params=QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES)
    def dominated_nodes_test_struct(request):
        """Parametrized (nodes to mark, start node -> expected dominated nodes) pair."""
        return request.param

    @staticmethod
    def mark_nodes_with_traits(qpsg: QPSG, node_keys_vs_traits_dict: Dict[str, QuantizationTrait]) -> QPSG:
        """Set the quantization trait of every node listed in the dict; other nodes are left untouched.

        The graph is modified in place and also returned for convenience.
        """
        for node_key, node in qpsg.nodes.items():
            if node_key in node_keys_vs_traits_dict:
                node[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] = node_keys_vs_traits_dict[node_key]
        return qpsg

    def test_get_quantizable_op_nodes_immediately_dominated_by_node(self, mock_qp_graph, dominated_nodes_test_struct):
        """Check the non-quantization-agnostic op nodes immediately dominated by each start node.

        All nodes start as QUANTIZATION_AGNOSTIC; the listed nodes are then marked
        with INPUTS_QUANTIZABLE and, in a second pass, NON_QUANTIZABLE. The expected
        result is the same for both traits.
        """
        nodes_to_mark_as_quantizable = dominated_nodes_test_struct[0]

        node_keys_vs_trait_dict = dict.fromkeys(mock_qp_graph.nodes, QuantizationTrait.QUANTIZATION_AGNOSTIC)

        traits_to_mark_with = [QuantizationTrait.INPUTS_QUANTIZABLE,
                               QuantizationTrait.NON_QUANTIZABLE]

        for trait in traits_to_mark_with:
            for node_key in nodes_to_mark_as_quantizable:
                node_keys_vs_trait_dict[node_key] = trait

            mock_qp_graph = self.mark_nodes_with_traits(mock_qp_graph, node_keys_vs_trait_dict)
            for start_node_key, ref_dominated_quantizable_nodes_set in dominated_nodes_test_struct[1].items():
                dominated_quantizable_nodes_list = \
                    mock_qp_graph.get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(start_node_key)
                assert set(dominated_quantizable_nodes_list) == ref_dominated_quantizable_nodes_set

    @staticmethod
    def get_model_graph():
        """Build the small branching mock model graph used by the merge tests (see diagram below)."""
        mock_node_attrs = get_mock_nncf_node_attrs()
        mock_graph = nx.DiGraph()

        #     (A)
        #      |
        #     (B)
        #   /     \
        # (C)     (D)
        #  |       |
        # (F)     (E)
        #
        #

        node_keys = ['A', 'B', 'C', 'D', 'E', 'F']
        for node_key in node_keys:
            mock_graph.add_node(node_key, **mock_node_attrs)

        mock_graph.add_edges_from([('A', 'B'), ('B', 'C'), ('B', 'D'), ('D', 'E'), ('C', 'F')])
        mark_input_ports_lexicographically_based_on_input_node_key(mock_graph)
        return mock_graph


    # One propagating quantizer in a merge test case:
    # - init_node_to_trait_and_configs_dict: node key -> (trait, qconfig list); INPUTS_QUANTIZABLE nodes get a PQ
    #   added at their pre-hook;
    # - starting_quantizer_ip_node: IP node where the PQ under test is added (in expected structs: list of start
    #   nodes whose paths to the target form the expected affected edges);
    # - target_node_for_quantizer: IP node the PQ is propagated to (in expected structs: where it must end up);
    # - is_merged / prop_path: whether the PQ is afterwards merged into a path, and which path.
    StateQuantizerTestStruct = namedtuple('StateQuantizerTestStruct',
                                          ('init_node_to_trait_and_configs_dict',
                                           'starting_quantizer_ip_node',
                                           'target_node_for_quantizer',
                                           'is_merged',
                                           'prop_path'))

    # A merge test case: the quantizers to set up, and the quantizers expected after merging.
    SetQuantizersTestStruct = namedtuple('SetQuantizersTestStruct',
                                         ('start_set_quantizers',
                                          'expected_set_quantizers'))

    MERGE_QUANTIZER_INTO_PATH_TEST_CASES = [
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None

                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('C'),
                    is_merged=True,
                    prop_path=[(InsertionPointGraph.get_post_hook_node_key('B'),
                                InsertionPointGraph.get_pre_hook_node_key('C'))]
                )

            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'PRE HOOK B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                                       [QuantizerConfig()])

                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        ),
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None

                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=True,
                    prop_path=[('B', InsertionPointGraph.get_post_hook_node_key('B'))]
                )

            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])

                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        ),
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None

                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('C'),
                    is_merged=True,
                    prop_path=[(InsertionPointGraph.get_post_hook_node_key('B'),
                                InsertionPointGraph.get_pre_hook_node_key('C'))]
                )

            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])

                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        )
    ]

    @staticmethod
    @pytest.fixture(params=MERGE_QUANTIZER_INTO_PATH_TEST_CASES)
    def merge_quantizer_into_path_test_struct(request):
        """Parametrized merge test case."""
        return request.param

    @pytest.fixture
    def model_graph_qpsg(self):
        """Return a QPSG built from ``get_model_graph()``."""
        mock_graph = self.get_model_graph()
        ip_graph = InsertionPointGraph(mock_graph)
        quant_prop_graph = QPSG(ip_graph)
        return quant_prop_graph

    def test_merge_quantizer_into_path(self, model_graph_qpsg, merge_quantizer_into_path_test_struct):
        """Set up and propagate the start quantizers, merge the flagged ones, then check the final state.

        Merges are performed only after all start quantizers have been set up,
        and every start struct flagged with ``is_merged`` contributes a merge.
        """
        quant_prop_graph = model_graph_qpsg
        # (propagated PQ, path to merge it into) pairs, accumulated over *all* start structs.
        merged_prop_quant = []

        for quantizers_test_struct in merge_quantizer_into_path_test_struct.start_set_quantizers:
            init_node_to_trait_and_configs_dict = quantizers_test_struct.init_node_to_trait_and_configs_dict
            starting_quantizer_ip_node = quantizers_test_struct.starting_quantizer_ip_node
            target_node = quantizers_test_struct.target_node_for_quantizer

            # Every node is agnostic except the ones this struct explicitly configures.
            node_key_vs_trait_dict = dict.fromkeys(quant_prop_graph.nodes,
                                                   QuantizationTrait.QUANTIZATION_AGNOSTIC)  # type: Dict[str, QuantizationTrait]
            for node_key, (trait, _) in init_node_to_trait_and_configs_dict.items():
                node_key_vs_trait_dict[node_key] = trait
            quant_prop_graph = self.mark_nodes_with_traits(quant_prop_graph, node_key_vs_trait_dict)

            primary_prop_quant = None
            for node_key, (trait, qconfigs) in init_node_to_trait_and_configs_dict.items():
                if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
                    ip_node_key = InsertionPointGraph.get_pre_hook_node_key(node_key)
                    prop_quant = quant_prop_graph.add_propagating_quantizer(qconfigs,
                                                                            ip_node_key)
                    if ip_node_key == starting_quantizer_ip_node:
                        primary_prop_quant = prop_quant
            # The struct must configure the node whose pre-hook is the starting IP node.
            assert primary_prop_quant is not None

            path = get_edge_paths_for_propagation(quant_prop_graph,
                                                  target_node,
                                                  starting_quantizer_ip_node)
            primary_prop_quant = quant_prop_graph.propagate_quantizer_via_path(primary_prop_quant,
                                                                               path[0])
            if quantizers_test_struct.is_merged:
                merged_prop_quant.append((primary_prop_quant, quantizers_test_struct.prop_path))

            quant_prop_graph.run_consistency_check()

        for prop_quant, prop_path in merged_prop_quant:
            quant_prop_graph.merge_quantizer_into_path(prop_quant, prop_path)
            quant_prop_graph.run_consistency_check()

        expected_quantizers_test_struct = merge_quantizer_into_path_test_struct.expected_set_quantizers
        self.check_final_state_qpsg(quant_prop_graph, expected_quantizers_test_struct)

    @staticmethod
    def check_final_state_qpsg(final_quant_prop_graph, expected_quantizers_test_struct):
        """Assert that each expected quantizer sits at its target node and affects exactly the expected edges.

        For each expected struct, the expected affected edges are the union of the
        first edge path from ``target_node_for_quantizer`` to each node listed in
        ``starting_quantizer_ip_node``. Insertion point nodes at the source of those
        edges must also list the quantizer as affecting them.
        """
        for quantizer_param in expected_quantizers_test_struct:
            target_node_key = quantizer_param.target_node_for_quantizer
            expected_affected_edges = set()
            for start_node_key in quantizer_param.starting_quantizer_ip_node:
                paths = get_edge_paths_for_propagation(final_quant_prop_graph,
                                                       target_node_key,
                                                       start_node_key)
                expected_affected_edges.update(paths[0])

            quantizer = final_quant_prop_graph.nodes[target_node_key][QPSG.PROPAGATING_QUANTIZER_NODE_ATTR]

            assert quantizer is not None

            for from_node_key, to_node_key in expected_affected_edges:
                edge = final_quant_prop_graph.edges[(from_node_key, to_node_key)]
                assert quantizer in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                from_node = final_quant_prop_graph.nodes[from_node_key]
                if from_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert quantizer in from_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

            assert quantizer.affected_edges == expected_affected_edges
Example #22
0
 def mock_qp_graph():
     """Yield a fresh QPSG built over the insertion point graph of the two-branch mock model."""
     yield QPSG(InsertionPointGraph(get_two_branch_mock_model_graph()))
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Add a propagating quantizer with a default-constructed config at the pre-hooks of 'B' and 'E'.

     Despite the name, no propagation is performed here: the quantizers stay at the
     insertion points where they were added. The same (mutated) graph is returned.
     """
     for op_node_key in ('B', 'E'):
         qpsg.add_propagating_quantizer([QuantizerConfig()],
                                        InsertionPointGraph.get_pre_hook_node_key(op_node_key))
     return qpsg