def test_add_propagating_quantizer(self, mock_qp_graph):
        """Check where a freshly added propagating quantizer is registered in the graph.

        The quantizer is added at the pre-hook insertion point of operator "F". It must be
        stored as that insertion point's propagating quantizer and listed as affecting
        operator "F" and the pre-hook -> "F" edge. No other node or edge may list any
        affecting quantizer. Adding a quantizer at the post-hook insertion point of "F"
        afterwards is expected to raise RuntimeError.
        """
        ref_qconf_list = [QuantizerConfig(), QuantizerConfig(bits=6)]

        target_node_key = "F"
        target_ip_node_key = InsertionPointGraph.get_pre_hook_node_key(target_node_key)
        prop_quant = mock_qp_graph.add_propagating_quantizer(ref_qconf_list, target_ip_node_key)
        assert prop_quant.potential_quant_configs == ref_qconf_list
        assert prop_quant.current_location_node_key == target_ip_node_key
        assert prop_quant.affected_ip_nodes == {target_ip_node_key}
        assert prop_quant.last_accepting_location_node_key is None
        assert prop_quant.affected_edges == {(target_ip_node_key, target_node_key)}
        assert not prop_quant.propagation_path

        for node_key, node in mock_qp_graph.nodes.items():
            if node_key == target_node_key:
                assert node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            elif node_key == target_ip_node_key:
                assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
            else:
                assert not node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                if node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        # Loop-invariant check: verified once after the scan instead of on every node.
        target_ip_node = mock_qp_graph.nodes[target_ip_node_key]
        assert target_ip_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

        for from_node, to_node, edge_data in mock_qp_graph.edges.data():
            if (from_node, to_node) == (target_ip_node_key, target_node_key):
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            else:
                assert not edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        with pytest.raises(RuntimeError):
            _ = mock_qp_graph.add_propagating_quantizer(ref_qconf_list,
                                                        InsertionPointGraph.get_post_hook_node_key(target_node_key))
Exemplo n.º 2
0
def test_quantization_configs__custom():
    """Custom 4-bit asymmetric weight/activation settings must reach every quantizer.

    Builds the compression algorithm from the basic config overridden with the
    "weights"/"activations" sections below. Each weight and activation quantizer's
    config is then compared against its reference via ``compare_qconfigs``.
    """
    model = BasicConvTestModel()

    config = get_basic_quantization_config()
    weights_section = {
        "mode": "asymmetric",
        "per_channel": True,
        "signed": False,
        "bits": 4
    }
    activations_section = {
        "mode": "asymmetric",
        "bits": 4,
        "signed": True,
    }
    config['compression'].update({"weights": weights_section, "activations": activations_section})

    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)
    algo = create_compression_algorithm(model, config)

    weight_quantizers, activation_quantizers = split_quantizers(algo.model)

    expected_weight_qconfig = QuantizerConfig(4, QuantizationMode.ASYMMETRIC, None, True, None, True)
    for quantizer in weight_quantizers:
        compare_qconfigs(expected_weight_qconfig, quantizer.config)

    expected_activation_qconfig = QuantizerConfig(4, QuantizationMode.ASYMMETRIC, True, False, None, False)
    for quantizer in activation_quantizers:
        compare_qconfigs(expected_activation_qconfig, quantizer.config)
    def test_remove_propagating_quantizer(self, mock_qp_graph, start_target_nodes_for_two_quantizers):
        """Removing one propagated quantizer must clear every graph reference to it.

        Two quantizers are added and propagated along independent paths; one is then removed.
        The removed quantizer must no longer appear on its final insertion point, its affected
        IP/operator nodes or its affected edges. The kept quantizer's state and graph
        registrations must stay unchanged.
        """
        start_ip_node_key_remove = start_target_nodes_for_two_quantizers[0]
        target_ip_node_key_remove = start_target_nodes_for_two_quantizers[1]

        start_ip_node_key_keep = start_target_nodes_for_two_quantizers[2]
        target_ip_node_key_keep = start_target_nodes_for_two_quantizers[3]

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path_remove = get_edge_paths_for_propagation(mock_qp_graph,
                                                         target_ip_node_key_remove,
                                                         start_ip_node_key_remove)[0]
        rev_path_keep = get_edge_paths_for_propagation(mock_qp_graph,
                                                       target_ip_node_key_keep,
                                                       start_ip_node_key_keep)[0]

        prop_quant_to_remove = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                       start_ip_node_key_remove)
        prop_quant_to_remove = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_remove, rev_path_remove)

        prop_quant_to_keep = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key_keep)

        prop_quant_to_keep = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_keep, rev_path_keep)

        # Snapshot everything the removed quantizer touched before removal, so it can be
        # checked for leftover references afterwards.
        affected_ip_nodes = deepcopy(prop_quant_to_remove.affected_ip_nodes)
        affected_op_nodes = deepcopy(prop_quant_to_remove.affected_operator_nodes)
        # These must be the edges of the quantizer being REMOVED; using the kept
        # quantizer's edges here would leave the removed quantizer's edges unchecked.
        affected_edges = deepcopy(prop_quant_to_remove.affected_edges)
        last_location = prop_quant_to_remove.current_location_node_key
        ref_quant_to_keep_state_dict = deepcopy(prop_quant_to_keep.__dict__)

        mock_qp_graph.remove_propagating_quantizer(prop_quant_to_remove)

        last_node = mock_qp_graph.nodes[last_location]
        assert last_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        for ip_node_key in affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for op_node_key in affected_op_nodes:
            node = mock_qp_graph.nodes[op_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_remove not in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The surviving quantizer must be completely unaffected by the removal.
        assert prop_quant_to_keep.__dict__ == ref_quant_to_keep_state_dict

        for ip_node_key in prop_quant_to_keep.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_keep in node[
                QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in prop_quant_to_keep.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_keep in edge[
                QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Add default quantizers at the 'C' and 'D' pre-hooks, merge them at the 'B' post-hook
     via merge_quantizers_for_branching_node, then add one more at the 'E' pre-hook.

     :return: The same ``qpsg`` object, mutated in place.
     """
     pre_hook = InsertionPointGraph.get_pre_hook_node_key
     branch_quantizers = [qpsg.add_propagating_quantizer([QuantizerConfig()], pre_hook(node_name))
                          for node_name in ('C', 'D')]
     qpsg.merge_quantizers_for_branching_node(branch_quantizers,
                                              [QuantizerConfig()],
                                              [None, None],
                                              InsertionPointGraph.get_post_hook_node_key('B'))
     qpsg.add_propagating_quantizer([QuantizerConfig()], pre_hook('E'))
     return qpsg
        def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
            """Propagate a default quantizer from the 'F' pre-hook over the single
            'C' post-hook -> 'F' pre-hook edge, then add a bits=6 quantizer at the 'F' pre-hook.

            :return: The same ``qpsg`` object, mutated in place.
            """
            f_pre_hook = InsertionPointGraph.get_pre_hook_node_key('F')
            first_quantizer = qpsg.add_propagating_quantizer([QuantizerConfig()], f_pre_hook)

            single_edge_path = [(InsertionPointGraph.get_post_hook_node_key('C'), f_pre_hook)]
            qpsg.propagate_quantizer_via_path(first_quantizer, single_edge_path)

            qpsg.add_propagating_quantizer([QuantizerConfig(bits=6)], f_pre_hook)
            return qpsg
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Add default quantizers, in this order, at: 'I' input port 0, 'I' input port 1,
     the 'C' pre-hook and the 'D' pre-hook.

     :return: The same ``qpsg`` object, mutated in place.
     """
     # This case will fail if, after going depth-first through the 'D' branch of the graph,
     # the merge traversal function state is not reset (which is incorrect behavior)
     # when starting to traverse the 'C' branch.
     start_node_keys = [
         InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=0),
         InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key('C'),
         InsertionPointGraph.get_pre_hook_node_key('D'),
     ]
     for node_key in start_node_keys:
         qpsg.add_propagating_quantizer([QuantizerConfig()], node_key)
     return qpsg
Exemplo n.º 7
0
def test_quantize_range_init_sets_correct_scale_shapes(
        quantizer_range_init_test_struct: Tuple[QRISSTS, str]):
    """Min/max range init must leave quantizer parameters with the reference scale shape.

    For both symmetric and asymmetric quantizers, a statistics collector of the requested
    init type is fed one all-ones input of the test-struct shape. Its min/max statistic is
    applied through ``apply_minmax_init``. The resulting ``scale`` (symmetric) or
    ``input_low``/``input_range`` (asymmetric) tensors must match ``ref_scale_shape``.
    """
    test_struct = quantizer_range_init_test_struct[0]
    initializer_type = quantizer_range_init_test_struct[1]
    for quantization_mode in [QuantizationMode.SYMMETRIC, QuantizationMode.ASYMMETRIC]:
        qconfig = QuantizerConfig(mode=quantization_mode,
                                  per_channel=test_struct.per_channel,
                                  is_weights=test_struct.is_weights,
                                  input_shape=test_struct.input_shape)
        q_cls = QUANTIZATION_MODULES.get(quantization_mode)
        quantizer = q_cls(qconfig)  # type: BaseQuantizer
        range_init_config = RangeInitConfig(init_type=initializer_type,
                                            num_init_samples=1)
        # The same shape key is used both to request and to look up the statistic.
        scale_shape_key = tuple(quantizer.scale_shape)
        collector = range_init_config.generate_stat_collector(
            reduction_shapes={scale_shape_key})
        collector.register_input(torch.ones(test_struct.input_shape))
        stat = collector.get_statistics()[scale_shape_key]
        minmax_values = MinMaxTensorStatistic.from_stat(stat)
        quantizer.apply_minmax_init(min_values=minmax_values.min_values,
                                    max_values=minmax_values.max_values)

        assert quantizer.scale_shape == test_struct.ref_scale_shape
        if quantization_mode == QuantizationMode.SYMMETRIC:
            assert list(quantizer.scale.shape) == test_struct.ref_scale_shape
        elif quantization_mode == QuantizationMode.ASYMMETRIC:
            assert list(quantizer.input_low.shape) == test_struct.ref_scale_shape
            assert list(quantizer.input_range.shape) == test_struct.ref_scale_shape
        else:
            # Explicit raise instead of `assert False`, which is removed under `python -O`.
            raise AssertionError("Unexpected quantization mode: {}".format(quantization_mode))
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Merge per-channel quantizers from the 'C'/'D' pre-hooks at the 'B' post-hook, then
     add a default quantizer at the 'E' pre-hook. That quantizer is propagated along the
     first path returned by get_edge_paths_for_propagation between the 'D' and 'E' pre-hooks.

     :return: The same ``qpsg`` object, mutated in place.
     """
     pre_hook = InsertionPointGraph.get_pre_hook_node_key
     per_channel_quantizers = [
         qpsg.add_propagating_quantizer([QuantizerConfig(per_channel=True)], pre_hook(node_name))
         for node_name in ('C', 'D')
     ]
     qpsg.merge_quantizers_for_branching_node(per_channel_quantizers,
                                              [QuantizerConfig(per_channel=True)],
                                              [None, None],
                                              InsertionPointGraph.get_post_hook_node_key('B'))
     e_quantizer = qpsg.add_propagating_quantizer([QuantizerConfig()], pre_hook('E'))
     first_path = get_edge_paths_for_propagation(qpsg, pre_hook('D'), pre_hook('E'))[0]
     qpsg.propagate_quantizer_via_path(e_quantizer, first_path)
     return qpsg
    def test_clone_propagating_quantizer(self, mock_qp_graph, start_target_nodes):
        """A clone of a propagated quantizer must mirror the original's state and registrations.

        Its affected IP nodes, affected edges, propagation path and current location must equal
        the original's. The clone must be listed on every IP node and edge the original affects.
        """
        start_ip_node_key, target_ip_node_key = start_target_nodes[0], start_target_nodes[1]

        # Paths are requested from "target" to "start" because propagation runs against the
        # edge direction; only the first of the possible paths is used.
        first_rev_path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]

        original_quantizer = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()], start_ip_node_key)
        propagated = mock_qp_graph.propagate_quantizer_via_path(original_quantizer, first_rev_path)
        clone = mock_qp_graph.clone_propagating_quantizer(propagated)

        for attr_name in ('affected_ip_nodes', 'affected_edges',
                          'propagation_path', 'current_location_node_key'):
            assert getattr(clone, attr_name) == getattr(propagated, attr_name), attr_name

        for ip_node_key in propagated.affected_ip_nodes:
            ip_node = mock_qp_graph.nodes[ip_node_key]
            assert clone in ip_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        for edge_from, edge_to in propagated.affected_edges:
            edge = mock_qp_graph.edges[edge_from, edge_to]
            assert clone in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The clone has not been placed at any insertion point (only one propagating quantizer
        # per insertion point is currently supported). The flag below presumably tells the
        # fixture to skip its graph consistency check — confirm against the fixture.
        mock_qp_graph.skip_check = True
    def test_backtrack_propagation_until_accepting_location(self, start_target_accepting_nodes, mock_qp_graph):
        """Backtracking must return a propagated quantizer to its last accepting location.

        The quantizer is propagated from the start IP to the target IP, and its last accepting
        location is then forced. After backtracking, the quantizer must sit at that location,
        with its propagation path equal to the start -> accepting-location path. The target IP
        must be vacated when it differs from the accepting location.
        """
        start_ip_node_key = start_target_accepting_nodes[0]
        target_ip_node_key = start_target_accepting_nodes[1]
        forced_last_accepting_location = start_target_accepting_nodes[2]
        # All checks below use the accepting location as a path endpoint and as a graph node
        # key, so a None value cannot be verified here. Fail early with a clear message instead
        # of an obscure NameError/KeyError further down.
        assert forced_last_accepting_location is not None, \
            "This test requires a non-None forced last accepting location"

        prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                             start_ip_node_key)
        ref_affected_edges = deepcopy(prop_quant.affected_edges)

        # Here, the tested graph should have such a structure that there is only one path from target to start
        path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]
        prop_quant = mock_qp_graph.propagate_quantizer_via_path(prop_quant, path)
        prop_quant.last_accepting_location_node_key = forced_last_accepting_location
        resulting_path = get_edge_paths_for_propagation(mock_qp_graph,
                                                        forced_last_accepting_location, start_ip_node_key)[0]
        ref_affected_edges.update(set(resulting_path))

        prop_quant = mock_qp_graph.backtrack_propagation_until_accepting_location(prop_quant)

        assert prop_quant.current_location_node_key == forced_last_accepting_location
        assert prop_quant.affected_edges == ref_affected_edges
        assert prop_quant.propagation_path == resulting_path

        target_node = mock_qp_graph.nodes[target_ip_node_key]
        accepting_node = mock_qp_graph.nodes[forced_last_accepting_location]
        if forced_last_accepting_location != target_ip_node_key:
            assert target_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
            assert target_ip_node_key not in prop_quant.affected_ip_nodes
        assert accepting_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
    def test_quantizers_can_propagate_via_path(self, start_target_nodes, mock_qp_graph):
        """Propagating along each possible path must move the quantizer to the path's end.

        For every path (tried on a fresh deep copy of the graph), each traversed edge must list
        exactly the quantizer, and traversed insertion points it passed through must be vacated.
        The quantizer must end at the source node of the last edge, with propagation path,
        affected edges and affected IP nodes updated accordingly.
        """
        start_ip_node_key = start_target_nodes[0]
        target_ip_node_key = start_target_nodes[1]
        ip_node_type = QuantizerPropagationStateGraphNodeType.INSERTION_POINT

        # From "target" to "start" since propagation direction is inverse to edge direction
        rev_paths = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)

        for path in rev_paths:
            graph = deepcopy(mock_qp_graph)
            initial_quantizer = graph.add_propagating_quantizer([QuantizerConfig()], start_ip_node_key)
            expected_edges = deepcopy(initial_quantizer.affected_edges)
            expected_edges.update(set(path))
            expected_ip_nodes = deepcopy(initial_quantizer.affected_ip_nodes)

            moved_quantizer = graph.propagate_quantizer_via_path(initial_quantizer, path)
            final_node_key, _ = path[-1]

            for edge_from, edge_to in path:
                edge_quantizers = graph.edges[edge_from, edge_to][QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                assert edge_quantizers == [initial_quantizer]
                to_node = graph.nodes[edge_to]
                if to_node[QPSG.NODE_TYPE_NODE_ATTR] == ip_node_type:
                    assert to_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                if graph.nodes[edge_from][QPSG.NODE_TYPE_NODE_ATTR] == ip_node_type:
                    expected_ip_nodes.add(edge_from)
            graph.run_consistency_check()

            final_node = graph.nodes[final_node_key]
            assert final_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == initial_quantizer

            assert moved_quantizer.current_location_node_key == final_node_key
            assert moved_quantizer.propagation_path == path
            assert moved_quantizer.affected_edges == expected_edges
            assert moved_quantizer.affected_ip_nodes == expected_ip_nodes
Exemplo n.º 12
0
def test_quantization_configs__with_defaults():
    """With the basic quantization config, the controller's weight and non-weight quantizers
    must all match the 8-bit symmetric reference configs below."""
    model = BasicConvTestModel()
    config = get_basic_quantization_config()
    _, ctrl = create_compressed_model_and_algo_for_test(model, config)

    assert isinstance(ctrl, QuantizationController)
    groups_to_check = (
        (ctrl.weight_quantizers,
         QuantizerConfig(8, QuantizationMode.SYMMETRIC, None, False, None, True)),
        (ctrl.non_weight_quantizers,
         QuantizerConfig(8, QuantizationMode.SYMMETRIC, None, False, None, False)),
    )
    for quantizers_by_id, expected_qconfig in groups_to_check:
        for quantizer in quantizers_by_id.values():
            compare_qconfigs(expected_qconfig, quantizer)
Exemplo n.º 13
0
def test_quantize_range_init_sets_correct_scale_shapes(
        quantizer_range_init_test_struct: Tuple[QRISSTS, str]):
    """Range initialization must leave quantizer parameters with the reference scale shape.

    For both symmetric and asymmetric quantizers, a range initializer of the requested type
    is created through ``RangeInitializerFactory`` and fed one all-ones input of the
    test-struct shape. It is then applied under ``torch.no_grad()``. The resulting ``scale``
    (symmetric) or ``input_low``/``input_range`` (asymmetric) tensors must match
    ``ref_scale_shape``.
    """
    test_struct = quantizer_range_init_test_struct[0]
    initializer_type = quantizer_range_init_test_struct[1]
    for quantization_mode in [QuantizationMode.SYMMETRIC, QuantizationMode.ASYMMETRIC]:
        qconfig = QuantizerConfig(mode=quantization_mode,
                                  per_channel=test_struct.per_channel,
                                  is_weights=test_struct.is_weights,
                                  input_shape=test_struct.input_shape)
        q_cls = QUANTIZATION_MODULES.get(quantization_mode)
        quantizer = q_cls(qconfig)  # type: BaseQuantizer
        init_config = {"type": initializer_type, "num_init_steps": 1}
        initializer = RangeInitializerFactory.create(init_config, quantizer, "")
        initializer.register_input(torch.ones(test_struct.input_shape))

        with torch.no_grad():
            initializer.apply_init()

        assert quantizer.scale_shape == test_struct.ref_scale_shape
        if quantization_mode == QuantizationMode.SYMMETRIC:
            assert list(quantizer.scale.shape) == test_struct.ref_scale_shape
        elif quantization_mode == QuantizationMode.ASYMMETRIC:
            assert list(quantizer.input_low.shape) == test_struct.ref_scale_shape
            assert list(quantizer.input_range.shape) == test_struct.ref_scale_shape
        else:
            # Explicit raise instead of `assert False`, which is removed under `python -O`.
            raise AssertionError("Unexpected quantization mode: {}".format(quantization_mode))
Exemplo n.º 14
0
    def get_qconf_from_hw_config_subdict(quantization_subdict: Dict, for_weights=False):
        """Build a QuantizerConfig from one quantization entry of a HW config.

        :param quantization_subdict: Entry with ``bits``, ``mode`` and ``granularity`` keys and,
            optionally, both ``level_low`` and ``level_high``.
        :param for_weights: Passed through as ``is_weights`` of the resulting config.
        :return: The corresponding ``QuantizerConfig``. ``signedness_to_force`` stays None unless
            both level keys are present. If they are, it is True iff
            ``level_low < 0 < level_high`` in symmetric mode, and always True otherwise.
        :raises AssertionError: If the explicit levels differ from the level range computed by
            the quantizer class for the given bit width.
        """
        bits = quantization_subdict["bits"]
        mode = HWConfig.get_quantization_mode_from_config_value(quantization_subdict["mode"])
        is_per_channel = HWConfig.get_is_per_channel_from_config_value(quantization_subdict["granularity"])
        signedness_to_force = None
        if 'level_low' in quantization_subdict and 'level_high' in quantization_subdict:
            level_low = quantization_subdict['level_low']
            level_high = quantization_subdict['level_high']
            signedness_to_force = False
            if mode == QuantizationMode.SYMMETRIC:
                if level_low < 0 < level_high:
                    signedness_to_force = True
                # NOTE(review): the expected range is always computed for signed=True, so an
                # unsigned symmetric level pair would fail the checks below — confirm intended.
                true_level_low, true_level_high, _ = SymmetricQuantizer.calculate_level_ranges(bits, True)
            else:
                signedness_to_force = True
                true_level_low, true_level_high, _ = AsymmetricQuantizer.calculate_level_ranges(bits)

            # Implicit string concatenation: the previous backslash-newline inside the literal
            # embedded the next line's leading whitespace into the message.
            assert level_low == true_level_low, \
                "Invalid value of quantizer parameter `level_low`. " \
                "The parameter must be consistent with other parameters!"
            assert level_high == true_level_high, \
                "Invalid value of quantizer parameter `level_high`. " \
                "The parameter must be consistent with other parameters!"

        return QuantizerConfig(bits=bits,
                               mode=mode,
                               per_channel=is_per_channel,
                               signedness_to_force=signedness_to_force,
                               is_weights=for_weights)
Exemplo n.º 15
0
def test_check_hawq_dump(mocker, tmp_path):
    """HAWQDebugger dumps must produce exactly six files in ``<debug log dir>/hawq_dumps``.

    A minimal two-configuration setup (2-bit and 4-bit) with stub perturbation observers is
    built, every dump method is invoked once, and the regular files in the dump directory
    are counted.
    """
    tensor1 = torch.Tensor([1])
    tensor2 = torch.Tensor([2])
    qconf1 = QuantizerConfig(bits=2)
    qconf2 = QuantizerConfig(bits=4)
    id_ = 0
    quantizer_configurations = [[qconf1, qconf1], [qconf2, qconf2]]
    flops_per_config = [value.item() for value in (tensor1, tensor2)]
    chosen_config_index = id_
    configuration_metric = [tensor1, tensor2]

    perturbations = Perturbations()
    perturbation_entries = (
        (id_, qconf1, tensor1),
        (id_, qconf2, tensor2),
        (id_ + 1, qconf1, tensor2),
        (id_ + 1, qconf2, tensor1),
    )
    for entry_id, qconf, value in perturbation_entries:
        perturbations.add(entry_id, qconf, value)

    weight_observers = []
    for perturbation in (tensor1, tensor2):
        observer = PerturbationObserver(mocker.stub())
        observer.perturbation = perturbation
        observer.numels = id_
        observer.input_norm = id_
        weight_observers.append(observer)
    traces_per_layer = TracesPerLayer(torch.cat((tensor1, tensor2)))

    set_debug_log_dir(str(tmp_path))
    hawq_debugger = HAWQDebugger(quantizer_configurations, perturbations,
                                 quantizer_configurations,
                                 [weight_observers, weight_observers],
                                 traces_per_layer, [qconf1.bits, qconf2.bits])

    hawq_debugger.dump_metric_MB(configuration_metric)
    hawq_debugger.dump_metric_flops(configuration_metric, flops_per_config, chosen_config_index)
    hawq_debugger.dump_avg_traces()
    hawq_debugger.dump_density_of_quantization_noise()
    hawq_debugger.dump_perturbations_ratio()

    dump_dir = tmp_path / Path('hawq_dumps')
    num_dump_files = sum(1 for entry in dump_dir.iterdir() if entry.is_file())
    assert num_dump_files == 6
Exemplo n.º 16
0
    def validate_spy(self):
        """Check the arguments recorded by the chosen-config and Hessian-trace spies.

        The bitwidth list passed to the chosen-config setter must have one entry per weight
        quantizer and must not consist solely of the default bitwidth. The data loader passed
        to the Hessian trace estimator must use ``batch_size_init`` if it is truthy, otherwise
        ``batch_size``.
        """
        chosen_bitwidths = self.set_chosen_config_spy.call_args[0][1]
        assert len(chosen_bitwidths) == self.n_weight_quantizers
        # With the default compression ratio of 1.5, the chosen precisions must not all equal
        # the default bitwidth.
        assert set(chosen_bitwidths) != {QuantizerConfig().bits}

        data_loader_for_init = self.hessian_trace_estimator_spy.call_args[0][5]
        expected_batch_size = self.batch_size_init or self.batch_size
        assert data_loader_for_init.batch_size == expected_batch_size
 def get_qconf_from_hw_config_subdict(quantization_subdict: Dict):
     """Build a QuantizerConfig from the ``bits``, ``mode`` and ``granularity`` entries of a
     HW config quantization subdict; other keys are ignored."""
     return QuantizerConfig(
         bits=quantization_subdict["bits"],
         mode=HWConfig.get_quantization_mode_from_config_value(quantization_subdict["mode"]),
         per_channel=HWConfig.get_is_per_channel_from_config_value(quantization_subdict["granularity"]),
     )
Exemplo n.º 18
0
def test_quantization_configs__with_defaults():
    """With the basic quantization config, every weight and activation quantizer config must
    match the 8-bit symmetric reference configs below."""
    model = BasicConvTestModel()
    config = get_basic_quantization_config()
    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)
    # A deep copy of the model is handed to the algorithm; `model` itself is not passed on.
    algo = create_compression_algorithm(deepcopy(model), config)

    weight_quantizers, activation_quantizers = split_quantizers(algo.model)

    checks = [
        (QuantizerConfig(8, QuantizationMode.SYMMETRIC, None, False, None, True), weight_quantizers),
        (QuantizerConfig(8, QuantizationMode.SYMMETRIC, None, False, None, False), activation_quantizers),
    ]
    for expected_qconfig, quantizers in checks:
        for quantizer in quantizers:
            compare_qconfigs(expected_qconfig, quantizer.config)
Exemplo n.º 19
0
def test_quantization_configs__custom():
    """Custom 4-bit asymmetric settings (target device 'NONE') must reach every quantizer.

    Weights are configured per-channel. Activations are configured with ``"signed": True``,
    which is expected to surface as ``signedness_to_force=True``.
    """
    model = BasicConvTestModel()

    config = get_quantization_config_without_range_init()
    config['compression'].update({
        "weights": {"mode": "asymmetric", "per_channel": True, "bits": 4},
        "activations": {"mode": "asymmetric", "bits": 4, "signed": True},
    })
    config['target_device'] = 'NONE'
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    assert isinstance(compression_ctrl, QuantizationController)
    weight_quantizers = compression_ctrl.weight_quantizers
    activation_quantizer_infos = compression_ctrl.non_weight_quantizers

    shared_kwargs = dict(bits=4, mode=QuantizationMode.ASYMMETRIC, input_shape=None)

    expected_weight_qconfig = QuantizerConfig(signedness_to_force=None, per_channel=True,
                                              is_weights=True, **shared_kwargs)
    for weight_quantizer in weight_quantizers.values():
        compare_qconfigs(expected_weight_qconfig, weight_quantizer)

    expected_activation_qconfig = QuantizerConfig(signedness_to_force=True, per_channel=False,
                                                  is_weights=False, **shared_kwargs)
    # Non-weight entries are info objects; the quantizer module is reached via
    # `quantizer_module_ref`.
    for info in activation_quantizer_infos.values():
        compare_qconfigs(expected_activation_qconfig, info.quantizer_module_ref)
Exemplo n.º 20
0
def generate_qp(scope_str: str,
                target: QuantizerGroup,
                in_port_id: int = None) -> SingleConfigQuantizationPoint:
    """Build a quantization point with a default ``QuantizerConfig``.

    :param scope_str: For weights, a module scope string (parsed by ``Scope.from_str``).
        For activations, an operation context string (parsed by
        ``InputAgnosticOperationExecutionContext.from_str``).
    :param target: ``QuantizerGroup.WEIGHTS`` or ``QuantizerGroup.ACTIVATIONS``.
    :param in_port_id: Activations only. ``None`` selects the operator post-hook; any other
        value selects the operator pre-hook on that input port. Ignored for weights.
    :return: A ``SingleConfigQuantizationPoint`` at the built insertion point.
    :raises RuntimeError: If ``target`` is neither weights nor activations.
    """
    if target is QuantizerGroup.WEIGHTS:
        ip = InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP,
                            module_scope=Scope.from_str(scope_str))
    elif target is QuantizerGroup.ACTIVATIONS:
        ip = InsertionPoint(
            InsertionType.OPERATOR_POST_HOOK
            if in_port_id is None else InsertionType.OPERATOR_PRE_HOOK,
            ia_op_exec_context=InputAgnosticOperationExecutionContext.from_str(
                scope_str),
            input_port_id=in_port_id)
    else:
        raise RuntimeError("Unsupported quantizer group for a quantization point: {}".format(target))
    return SingleConfigQuantizationPoint(ip, QuantizerConfig())
Exemplo n.º 21
0
def test_flops(config_creator, ref_values):
    """Check CompressionRatioCalculator results for a conv -> linear model.

    ``ref_values`` holds, in order: the ratios for bit configurations [4, 8] and [8, 4], the
    ratio limits for bitwidths [4, 8] and [2, 4, 8], and the ratio limits for [2, 8] when the
    first weight quantizer is constrained to 8 bits.
    """
    class ConvLinear(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = create_conv(1, 1, 2, -1, -2)
            self.fc = nn.Linear(3, 6)

        def forward(self, x):
            conv_out = self.conv1(x)
            return self.fc(conv_out)

    config = config_creator()
    model, compression_ctrl = create_compressed_model_and_algo_for_test(ConvLinear(), config)
    weight_quantizers = compression_ctrl.weight_quantizers

    quantizers_handler = WeightQuantizersHandler(model, weight_quantizers, HWPrecisionConstraints(True))
    flops_counter = CompressionRatioCalculator(model, quantizers_handler)

    for ref_idx, bits_config in enumerate(([4, 8], [8, 4])):
        assert flops_counter.ratio_for_bits_configuration(bits_config) == ref_values[ref_idx]
    for ref_idx, bitwidths in enumerate(([4, 8], [2, 4, 8]), start=2):
        assert flops_counter.ratio_limits(bitwidths) == ref_values[ref_idx]

    first_quantizer_key = list(weight_quantizers)[0]
    constraints = HWPrecisionConstraints(True).add(first_quantizer_key, [QuantizerConfig(bits=8)])
    assert flops_counter.ratio_limits([2, 8], constraints) == ref_values[4]

if __name__ == '__main__':
    # Manual CUDA profiling script: for each (name, input size, run count) entry of
    # TEST_PARAMS_STRUCT, profile the reference quantizer ("Pytorch" in the printed labels)
    # and the custom SymmetricQuantizer on the 'cuda' device via run_profile.
    for input_name, input_size, gpu_runs in TEST_PARAMS_STRUCT:
        print("CUDA " + input_name)
        print("------------------------------------------------")
        # Per-tensor reference implementation.
        print("Pytorch Symmetric (cuda 0) impl:")
        print("input size: {0}".format(input_size))
        run_profile(
            ReferenceQuantize(NBITS).cuda(), input_size, 'cuda', gpu_runs)

        # Per-tensor custom implementation.
        print()
        print("Custom Symmetric (cuda 0 ) impl:")
        print("input size: {0}".format(input_size))
        run_profile(
            SymmetricQuantizer(QuantizerConfig(
                QuantizationParams(bits=NBITS))).cuda(), input_size, 'cuda',
            gpu_runs)

        # Per-channel (weights) reference implementation.
        print()
        print("Pytorch Symmetric Per Weight Channel (cuda 0) impl:")
        print("input size: {0}".format(input_size))
        run_profile(
            ReferenceQuantize(NBITS,
                              input_shape=input_size,
                              per_channel=True,
                              is_weights=True).cuda(), input_size, 'cuda',
            gpu_runs)

        # NOTE(review): this section prints its header, but no run_profile call follows, so
        # the custom per-channel implementation is never profiled. The snippet looks
        # truncated; confirm and add the missing call.
        print()
        print("Custom Symmetric Per Weight Channel  (cuda 0 ) impl")
        print("input size: {0}".format(input_size))
Exemplo n.º 23
0
    # nx_graph is expected to have version-agnostic operator names already
    for k, attrs in nx_graph.nodes.items():
        attrs = {k: str(v) for k, v in attrs.items()}
        load_attrs = {k: str(v).strip('"') for k, v in load_graph.nodes[k].items()}
        assert attrs == load_attrs

    assert load_graph.nodes.keys() == nx_graph.nodes.keys()
    assert nx.DiGraph(load_graph).edges == nx_graph.edges


# Pairs a quantizer factory with the name of the directory holding its reference graphs.
QuantizeConfig = namedtuple('QuantizeConfig', ['quantizer', 'graph_dir'])

# Each factory is called as factory(_, is_weights=False, input_shape=None); the first
# positional argument is ignored. Both factories share that signature so they can be called
# interchangeably. The asymmetric factory currently ignores is_weights and input_shape and
# always builds a default-config AsymmetricQuantizer.
QUANTIZERS = [
    QuantizeConfig(lambda _, is_weights=False, input_shape=None: SymmetricQuantizer(
        QuantizerConfig(signedness_to_force=is_weights, is_weights=is_weights, input_shape=input_shape)),
                   'symmetric'),
    QuantizeConfig(lambda _, is_weights=False, input_shape=None: AsymmetricQuantizer(QuantizerConfig()),
                   'asymmetric')
]


@pytest.fixture(scope='function', params=QUANTIZERS, ids=[entry.graph_dir for entry in QUANTIZERS])
def _quantize_config(request):
    """Return the parametrized QUANTIZERS entry with its graph directory placed under 'quantized'."""
    quantizer_factory, graph_subdir = request.param
    return QuantizeConfig(quantizer_factory, os.path.join('quantized', graph_subdir))


def default_forward_fn(model, input_size_):
    """Run ``model`` once on an all-zeros tensor of shape ``input_size_``.

    The input is moved to the device of the model's first parameter, so the model must have
    at least one parameter.
    """
    first_param = next(model.parameters())
    zeros_input = torch.zeros(input_size_).to(first_param.device)
    return model(zeros_input)
Exemplo n.º 24
0
def test_onnx_export_to_quantize_dequantize_per_channel():
    """Exercise QuantizeLinear/DequantizeLinear ONNX export for symmetric and asymmetric quantizers.

    For each mode the test exports, in order:
    - a per-channel quantizer built with an explicit input shape;
    - a per-tensor quantizer;
    - a second per-channel quantizer, for which a RuntimeError with the "doesn't support per
      channel quantization" message is tolerated.
    """
    # SYMMETRIC
    # Per-channel symmetric quantizer with an explicit input shape; export is expected to succeed.
    q_config = QuantizerConfig(input_shape=(2, 64, 15, 10),
                               bits=8,
                               mode=QuantizationMode.SYMMETRIC,
                               signedness_to_force=None,
                               per_channel=True)
    sym_quantizer = SymmetricQuantizer(q_config)
    # pylint: disable=protected-access
    sym_quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS

    x = torch.rand((2, 64, 15, 10))
    sym_quantizer.run_export_quantization(x)

    # Per-tensor symmetric quantizer (no input shape needed).
    q_config = QuantizerConfig(bits=8,
                               mode=QuantizationMode.SYMMETRIC,
                               signedness_to_force=None,
                               per_channel=False)
    sym_quantizer = SymmetricQuantizer(q_config)
    # pylint: disable=protected-access
    sym_quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS

    x = torch.rand((2, 64, 15, 10))
    sym_quantizer.run_export_quantization(x)

    # Per-channel symmetric quantizer whose scale is replaced by a (1, 64, 1, 1) parameter.
    q_config = QuantizerConfig(input_shape=(2, 64, 15, 10),
                               bits=8,
                               mode=QuantizationMode.SYMMETRIC,
                               signedness_to_force=None,
                               per_channel=True)
    sym_quantizer = SymmetricQuantizer(q_config)
    # pylint: disable=protected-access
    sym_quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS
    sym_quantizer.scale = torch.nn.Parameter(torch.rand(1, 64, 1, 1))

    x = torch.rand((2, 64, 15, 10))
    # NOTE(review): if no RuntimeError is raised the test still passes; the except branch only
    # validates the error message when one occurs.
    try:
        sym_quantizer.run_export_quantization(x)
    except RuntimeError as e:
        assert str(
            e) == "PyTorch v1.5.0 export to ONNX using QuantizeLinear-DequantizeLinear " \
                  "doesn't support per channel quantization"
    # ASYMMETRIC
    # Per-channel asymmetric quantizer with an explicit input shape; export is expected to succeed.
    q_config = QuantizerConfig(input_shape=(2, 64, 15, 10),
                               bits=8,
                               mode=QuantizationMode.ASYMMETRIC,
                               signedness_to_force=None,
                               per_channel=True)
    assym_quantizer = AsymmetricQuantizer(q_config)
    # pylint: disable=protected-access
    assym_quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS

    x = torch.rand((2, 64, 15, 10))
    assym_quantizer.run_export_quantization(x)

    # Per-tensor asymmetric quantizer.
    q_config = QuantizerConfig(bits=8,
                               mode=QuantizationMode.ASYMMETRIC,
                               signedness_to_force=None,
                               per_channel=False)
    assym_quantizer = AsymmetricQuantizer(q_config)
    # pylint: disable=protected-access
    assym_quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS

    x = torch.rand((2, 64, 15, 10))
    assym_quantizer.run_export_quantization(x)

    q_config = QuantizerConfig(input_shape=(2, 64, 15, 10),
                               bits=8,
                               mode=QuantizationMode.ASYMMETRIC,
                               signedness_to_force=None,
                               per_channel=True)
    assym_quantizer = AsymmetricQuantizer(q_config)
    # pylint: disable=protected-access
    assym_quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS
    # NOTE(review): this assigns to ``sym_quantizer`` (the symmetric quantizer from above), not to
    # ``assym_quantizer``, so the asymmetric quantizer exported below is left unmodified. It looks
    # like a copy-paste slip from the symmetric section; confirm which asymmetric parameter(s) were
    # meant to be overridden here.
    sym_quantizer.scale = torch.nn.Parameter(torch.rand(1, 64, 1, 1))

    x = torch.rand((2, 64, 15, 10))
    # NOTE(review): as above, a missing RuntimeError does not fail the test.
    try:
        assym_quantizer.run_export_quantization(x)
    except RuntimeError as e:
        assert str(
            e) == "PyTorch v1.5.0 export to ONNX using QuantizeLinear-DequantizeLinear" \
                  " doesn't support per channel quantization"
Exemplo n.º 25
0
def get_mock_precision_constraints(constraints, ordered_weight_keys):
    """Build HW precision constraints mapping each weight key to one QuantizerConfig per allowed bitwidth.

    ``constraints[i]`` holds the bitwidths allowed for ``ordered_weight_keys[i]``. Extra entries in the
    longer of the two sequences are ignored, since they are paired with ``zip``.
    """
    result = HWPrecisionConstraints(True)
    for weight_key, allowed_bitwidths in zip(ordered_weight_keys, constraints):
        per_bitwidth_configs = [QuantizerConfig(bits=b) for b in allowed_bitwidths]
        result.add(weight_key, per_bitwidth_configs)
    return result
 def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
     """Add a default-config propagating quantizer at the pre-hooks of 'B' and then 'E'.

     Despite the name, only ``add_propagating_quantizer`` is called here. The same ``qpsg`` is returned.
     """
     for op_node_key in ('B', 'E'):
         qpsg.add_propagating_quantizer([QuantizerConfig()],
                                        InsertionPointGraph.get_pre_hook_node_key(op_node_key))
     return qpsg
class TestQuantizerPropagationStateGraph:
    #pylint:disable=too-many-public-methods
    @staticmethod
    @pytest.fixture()
    def mock_qp_graph():
        """Yield a QPSG built over the two-branch mock model graph.

        On teardown the graph's consistency check runs, unless the test set ``skip_check`` to True.
        """
        graph = QPSG(InsertionPointGraph(get_two_branch_mock_model_graph()))
        graph.skip_check = False
        yield graph
        if graph.skip_check:
            return
        graph.run_consistency_check()

    def test_build_quantizer_propagation_state_graph_from_ip_graph(self):
        """A QPSG built from an InsertionPointGraph mirrors its nodes and edges and starts with no quantizers."""
        source_graph = InsertionPointGraph(get_two_branch_mock_model_graph())
        qpsg = QPSG(source_graph)
        assert len(source_graph.nodes) == len(qpsg.nodes)
        assert len(source_graph.edges) == len(qpsg.edges)

        ip_node_type = QuantizerPropagationStateGraphNodeType.INSERTION_POINT
        op_node_type = QuantizerPropagationStateGraphNodeType.OPERATOR
        for key, src_attrs in source_graph.nodes.items():
            qpsg_attrs = qpsg.nodes[key]
            assert qpsg_attrs[QPSG.NODE_TYPE_NODE_ATTR] == QPSG.ipg_node_type_to_qpsg_node_type(
                src_attrs[InsertionPointGraph.NODE_TYPE_NODE_ATTR])
            node_type = qpsg_attrs[QPSG.NODE_TYPE_NODE_ATTR]

            if node_type == ip_node_type:
                # Insertion points start empty but carry over their insertion point data.
                assert qpsg_attrs[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                assert not qpsg_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                expected_ip_data = src_attrs[InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
                assert qpsg_attrs[QPSG.INSERTION_POINT_DATA_NODE_ATTR] == expected_ip_data
            elif node_type == op_node_type:
                # Operators start as non-quantizable with no allowed input quantization types.
                assert not qpsg_attrs[QPSG.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR]
                assert qpsg_attrs[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] == QuantizationTrait.NON_QUANTIZABLE
                assert not qpsg_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for src_key, dst_key, src_edge_attrs in source_graph.edges(data=True):
            qpsg_edge_attrs = qpsg.edges[src_key, dst_key]
            assert not qpsg_edge_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            # Every original edge attribute must be copied over unchanged.
            for attr_name, attr_value in src_edge_attrs.items():
                assert qpsg_edge_attrs[attr_name] == attr_value

        qpsg.run_consistency_check()

    def test_add_propagating_quantizer(self, mock_qp_graph):
        """Adding a propagating quantizer at an IP node records it only on that node, its operator and their edge.

        Also checks that adding a quantizer at the post-hook of the same operator raises RuntimeError.
        """
        ref_qconf_list = [QuantizerConfig(), QuantizerConfig(bits=6)]

        target_node_key = "F"
        target_ip_node_key = InsertionPointGraph.get_pre_hook_node_key(target_node_key)
        prop_quant = mock_qp_graph.add_propagating_quantizer(ref_qconf_list, target_ip_node_key)
        assert prop_quant.potential_quant_configs == ref_qconf_list
        assert prop_quant.current_location_node_key == target_ip_node_key
        assert prop_quant.affected_ip_nodes == {target_ip_node_key}
        assert prop_quant.last_accepting_location_node_key is None
        assert prop_quant.affected_edges == {(target_ip_node_key, target_node_key)}
        assert not prop_quant.propagation_path

        for node_key, node in mock_qp_graph.nodes.items():
            if node_key == target_node_key:
                assert node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            elif node_key == target_ip_node_key:
                assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
            else:
                assert not node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                if node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        # Checked once, outside the node loop: this does not depend on the loop variables.
        target_ip_node = mock_qp_graph.nodes[target_ip_node_key]
        assert target_ip_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

        for from_node, to_node, edge_data in mock_qp_graph.edges.data():
            if (from_node, to_node) == (target_ip_node_key, target_node_key):
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            else:
                assert not edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        with pytest.raises(RuntimeError):
            _ = mock_qp_graph.add_propagating_quantizer(ref_qconf_list,
                                                        InsertionPointGraph.get_post_hook_node_key(target_node_key))

    # Each entry is (start insertion-point node key, list of reference paths), where every path is a list
    # of (from_node_key, to_node_key) edges. test_get_paths_to_immediately_dominating_insertion_points
    # compares these paths against the computed ones as a multiset, so path order is not significant.
    START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES = [
        # Non-branching case - starting from "E" pre-hook
        (InsertionPointGraph.get_pre_hook_node_key("E"),
         [[(InsertionPointGraph.get_post_hook_node_key("C"), InsertionPointGraph.get_pre_hook_node_key("E"))]]),

        # Non-branching case - starting from "C" post-hook
        (InsertionPointGraph.get_post_hook_node_key("C"),
         [[("C", InsertionPointGraph.get_post_hook_node_key("C")),
           (InsertionPointGraph.get_pre_hook_node_key("C"), "C")]]),

        # Branching case - starting from "F" pre-hook port 0
        (InsertionPointGraph.get_pre_hook_node_key("F"),
         [[(InsertionPointGraph.get_post_hook_node_key("D"), InsertionPointGraph.get_pre_hook_node_key("F"))]]),

        # Branching case - starting from "F" pre-hook port 1
        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         [[(InsertionPointGraph.get_post_hook_node_key("E"),
            InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1))]]),

    ]

    @staticmethod
    @pytest.fixture(params=START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES)
    def start_ip_node_and_path_to_dominating_node(request):
        """Supply one (start node, reference paths) case per parametrized run."""
        current_case = request.param
        return current_case

    def test_get_paths_to_immediately_dominating_insertion_points(self, start_ip_node_and_path_to_dominating_node,
                                                                  mock_qp_graph):
        """The paths found from the start node match the reference paths, ignoring path order."""
        start_node, ref_paths = start_ip_node_and_path_to_dominating_node
        test_paths = mock_qp_graph.get_paths_to_immediately_dominating_insertion_points(start_node)

        def to_path_strings(path_list: List[List[Tuple[str, str]]]):
            # Render each path as "a -> b;c -> d" so paths can be compared as hashable strings.
            return [';'.join(' -> '.join((str(edge[0]), str(edge[1]))) for edge in path)
                    for path in path_list]

        assert Counter(to_path_strings(ref_paths)) == Counter(to_path_strings(test_paths))


    # (start IP node key, target IP node key) pairs. A quantizer is placed at the start node, and the
    # paths used for propagation are found from the target back to the start with
    # get_edge_paths_for_propagation.
    START_TARGET_NODES = [
        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("F")),

        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key("E")),

        (InsertionPointGraph.get_pre_hook_node_key("F"),
         InsertionPointGraph.get_post_hook_node_key("B")),
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_NODES)
    def start_target_nodes(request):
        """Supply one (start IP node, target IP node) pair per parametrized run."""
        current_pair = request.param
        return current_pair

    @pytest.mark.dependency(name="propagate_via_path")
    def test_quantizers_can_propagate_via_path(self, start_target_nodes, mock_qp_graph):
        """Propagating a quantizer along each target-to-start path updates its location, path and affected sets.

        Each path is applied to a deep copy of the fixture graph, so the paths are tested independently.
        """
        start_ip_node_key = start_target_nodes[0]
        target_ip_node_key = start_target_nodes[1]

        # From "target" to "start" since propagation direction is inverse to edge direction
        rev_paths = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)

        for path in rev_paths:
            working_graph = deepcopy(mock_qp_graph)
            ref_prop_quant = working_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key)
            ref_affected_edges = deepcopy(ref_prop_quant.affected_edges)
            ref_affected_edges.update(set(path))
            ref_affected_ip_nodes = deepcopy(ref_prop_quant.affected_ip_nodes)
            prop_quant = working_graph.propagate_quantizer_via_path(ref_prop_quant, path)
            for from_node_key, to_node_key in path:
                edge_data = working_graph.edges[from_node_key, to_node_key]
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [ref_prop_quant]
                # Downstream IP nodes on the path must no longer hold the quantizer...
                to_node = working_graph.nodes[to_node_key]
                if to_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert to_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                # ...but every IP node the path passes through counts as affected.
                from_node = working_graph.nodes[from_node_key]
                if from_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    ref_affected_ip_nodes.add(from_node_key)
            working_graph.run_consistency_check()

            # The quantizer ends up at the source node of the last edge in the path.
            final_node_key, _ = path[-1]
            final_node = working_graph.nodes[final_node_key]
            assert final_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == ref_prop_quant

            assert prop_quant.current_location_node_key == final_node_key
            assert prop_quant.propagation_path == path
            assert prop_quant.affected_edges == ref_affected_edges
            assert prop_quant.affected_ip_nodes == ref_affected_ip_nodes

    # (start IP node key, target IP node key, forced last accepting location) triples. The quantizer is
    # propagated from start to target, its last accepting location is forced to the third element, and
    # backtracking is expected to return it there.
    START_TARGET_ACCEPTING_NODES = [
        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("F"),
         InsertionPointGraph.get_post_hook_node_key("F")),

        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key("C"),
         InsertionPointGraph.get_post_hook_node_key("C")),

        (InsertionPointGraph.get_pre_hook_node_key("D"),
         InsertionPointGraph.get_pre_hook_node_key("B"),
         InsertionPointGraph.get_post_hook_node_key("B")),
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_ACCEPTING_NODES)
    def start_target_accepting_nodes(request):
        """Supply one (start, target, forced accepting location) triple per parametrized run."""
        current_triple = request.param
        return current_triple

    # pytest-dependency iterates ``depends``, so it must be a list of names; a bare string would be
    # read character by character and the test would always be skipped.
    @pytest.mark.dependency(depends=["propagate_via_path"])
    def test_backtrack_propagation_until_accepting_location(self, start_target_accepting_nodes, mock_qp_graph):
        """Backtracking moves a propagated quantizer back to its (forced) last accepting location.

        After backtracking, the quantizer's location, propagation path and affected edges must match
        propagation straight from the start node to that accepting location.
        """
        start_ip_node_key = start_target_accepting_nodes[0]
        target_ip_node_key = start_target_accepting_nodes[1]
        forced_last_accepting_location = start_target_accepting_nodes[2]

        prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                             start_ip_node_key)
        ref_affected_edges = deepcopy(prop_quant.affected_edges)

        # Here, the tested graph should have such a structure that there is only one path from target to start
        path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]
        prop_quant = mock_qp_graph.propagate_quantizer_via_path(prop_quant, path)
        prop_quant.last_accepting_location_node_key = forced_last_accepting_location

        # Every parametrization supplies a concrete accepting location, and the assertions below need the
        # path to it. It is computed unconditionally so ``resulting_path`` is always bound.
        resulting_path = get_edge_paths_for_propagation(mock_qp_graph,
                                                        forced_last_accepting_location, start_ip_node_key)[0]
        ref_affected_edges.update(set(resulting_path))

        prop_quant = mock_qp_graph.backtrack_propagation_until_accepting_location(prop_quant)

        assert prop_quant.current_location_node_key == forced_last_accepting_location
        assert prop_quant.affected_edges == ref_affected_edges
        assert prop_quant.propagation_path == resulting_path

        target_node = mock_qp_graph.nodes[target_ip_node_key]
        accepting_node = mock_qp_graph.nodes[forced_last_accepting_location]
        if forced_last_accepting_location != target_ip_node_key:
            # The abandoned target must be fully released.
            assert target_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
            assert target_ip_node_key not in prop_quant.affected_ip_nodes
        assert accepting_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

    # ``depends`` must be a list of names; pytest-dependency iterates it, so a bare string would be read
    # character by character.
    @pytest.mark.dependency(depends=["propagate_via_path"])
    def test_clone_propagating_quantizer(self, mock_qp_graph, start_target_nodes):
        """A cloned propagating quantizer copies the original's state and is registered on the same nodes/edges."""
        start_ip_node_key = start_target_nodes[0]
        target_ip_node_key = start_target_nodes[1]

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]

        ref_prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                 start_ip_node_key)

        prop_quant = mock_qp_graph.propagate_quantizer_via_path(ref_prop_quant, rev_path)

        cloned_prop_quant = mock_qp_graph.clone_propagating_quantizer(prop_quant)

        assert cloned_prop_quant.affected_ip_nodes == prop_quant.affected_ip_nodes
        assert cloned_prop_quant.affected_edges == prop_quant.affected_edges
        assert cloned_prop_quant.propagation_path == prop_quant.propagation_path
        assert cloned_prop_quant.current_location_node_key == prop_quant.current_location_node_key

        for ip_node_key in prop_quant.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert cloned_prop_quant in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        for from_node_key, to_node_key in prop_quant.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert cloned_prop_quant in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The cloned quantizer had not been put into any IP (cannot have multiple PQs in one IP right now)
        mock_qp_graph.skip_check = True

    # Each entry is (start IP node of the quantizer to remove, its target IP node,
    #                start IP node of the quantizer to keep, its target IP node).
    START_TARGET_NODES_FOR_TWO_QUANTIZERS = [
        (InsertionPointGraph.get_pre_hook_node_key("E"),
         InsertionPointGraph.get_post_hook_node_key("C"),
         InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("C"),
         InsertionPointGraph.get_post_hook_node_key("A"),
         InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("D")),

        # Simulated quantizer merging result
        (InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_pre_hook_node_key("E"),
         InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("D"))
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_NODES_FOR_TWO_QUANTIZERS)
    def start_target_nodes_for_two_quantizers(request):
        """Supply one (remove start, remove target, keep start, keep target) case per parametrized run."""
        current_case = request.param
        return current_case

    # ``depends`` must be a list of names; pytest-dependency iterates it, so a bare string would be read
    # character by character.
    @pytest.mark.dependency(depends=["propagate_via_path"])
    def test_remove_propagating_quantizer(self, mock_qp_graph, start_target_nodes_for_two_quantizers):
        """Removing one propagated quantizer clears all its traces and leaves a second quantizer untouched."""
        start_ip_node_key_remove = start_target_nodes_for_two_quantizers[0]
        target_ip_node_key_remove = start_target_nodes_for_two_quantizers[1]

        start_ip_node_key_keep = start_target_nodes_for_two_quantizers[2]
        target_ip_node_key_keep = start_target_nodes_for_two_quantizers[3]

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path_remove = get_edge_paths_for_propagation(mock_qp_graph,
                                                         target_ip_node_key_remove,
                                                         start_ip_node_key_remove)[0]
        rev_path_keep = get_edge_paths_for_propagation(mock_qp_graph,
                                                       target_ip_node_key_keep,
                                                       start_ip_node_key_keep)[0]

        prop_quant_to_remove = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                       start_ip_node_key_remove)
        prop_quant_to_remove = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_remove, rev_path_remove)

        prop_quant_to_keep = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key_keep)

        prop_quant_to_keep = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_keep, rev_path_keep)

        # Snapshot everything the removed quantizer touched, so we can verify it was all cleaned up.
        affected_ip_nodes = deepcopy(prop_quant_to_remove.affected_ip_nodes)
        affected_op_nodes = deepcopy(prop_quant_to_remove.affected_operator_nodes)
        # Taken from the quantizer being removed, like the node sets above. Snapshotting the kept
        # quantizer's edges would skip checking the removed quantizer's own edges.
        affected_edges = deepcopy(prop_quant_to_remove.affected_edges)
        last_location = prop_quant_to_remove.current_location_node_key
        ref_quant_to_keep_state_dict = deepcopy(prop_quant_to_keep.__dict__)

        mock_qp_graph.remove_propagating_quantizer(prop_quant_to_remove)

        last_node = mock_qp_graph.nodes[last_location]
        assert last_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        for ip_node_key in affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for op_node_key in affected_op_nodes:
            node = mock_qp_graph.nodes[op_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_remove not in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The kept quantizer's own state and its registrations in the graph must be unchanged.
        assert prop_quant_to_keep.__dict__ == ref_quant_to_keep_state_dict

        for ip_node_key in prop_quant_to_keep.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_keep in node[
                QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in prop_quant_to_keep.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_keep in edge[
                QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    # Each entry is (operator node keys to mark with a non-agnostic trait,
    #                {start node key: expected set of op nodes returned by
    #                 get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(start node key)}).
    # All other nodes are marked QUANTIZATION_AGNOSTIC by the test.
    QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES = [
        (["D", "E", "F"],
         {
             "B": {"D", "E"},
             InsertionPointGraph.get_pre_hook_node_key("B"): {"D", "E"},
             "E": {"F"},
             InsertionPointGraph.get_post_hook_node_key("D"): {"F"},
             "A": {"D", "E"},
             InsertionPointGraph.get_pre_hook_node_key("G"): set()
         }),
        (["C", "E", "H"],
         {
             InsertionPointGraph.get_pre_hook_node_key("C"): {"C"},
             InsertionPointGraph.get_post_hook_node_key("C"): {"E"},
             "D": {"H"},
             InsertionPointGraph.get_pre_hook_node_key("B"): {"C", "H"}, # corner case - has a branch without quantizers
             InsertionPointGraph.get_post_hook_node_key("H"): set()
         })
    ]

    @staticmethod
    @pytest.fixture(params=QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES)
    def dominated_nodes_test_struct(request):
        """Supply one (nodes to mark, start node -> expected dominated nodes) case per parametrized run."""
        current_case = request.param
        return current_case

    @staticmethod
    def mark_nodes_with_traits(qpsg: QPSG, node_keys_vs_traits_dict: Dict[str, QuantizationTrait]) -> QPSG:
        """Set the quantization-trait attribute on each ``qpsg`` node that appears in the dict.

        Nodes not listed are left as they are. The graph is modified in place and also returned.
        """
        trait_attr_name = QPSG.QUANTIZATION_TRAIT_NODE_ATTR
        for key, attrs in qpsg.nodes.items():
            if key not in node_keys_vs_traits_dict:
                continue
            attrs[trait_attr_name] = node_keys_vs_traits_dict[key]
        return qpsg

    def test_get_quantizable_op_nodes_immediately_dominated_by_node(self, mock_qp_graph, dominated_nodes_test_struct):
        """Immediately dominated non-agnostic op nodes match the reference, whichever non-agnostic trait is used."""
        marked_node_keys, ref_dominated_by_start_node = dominated_nodes_test_struct

        # Everything starts quantization-agnostic; only the listed nodes get the trait under test.
        trait_by_node_key = {key: QuantizationTrait.QUANTIZATION_AGNOSTIC for key in mock_qp_graph.nodes}

        for trait in (QuantizationTrait.INPUTS_QUANTIZABLE, QuantizationTrait.NON_QUANTIZABLE):
            trait_by_node_key.update(dict.fromkeys(marked_node_keys, trait))
            mock_qp_graph = self.mark_nodes_with_traits(mock_qp_graph, trait_by_node_key)

            for start_node_key, ref_nodes in ref_dominated_by_start_node.items():
                found_nodes = \
                    mock_qp_graph.get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(start_node_key)
                assert set(found_nodes) == ref_nodes

    @staticmethod
    def get_model_graph():
        """Build the six-node mock model graph below, with input ports marked lexicographically.

        #     (A)
        #      |
        #     (B)
        #   /     \\
        # (C)     (D)
        #  |       |
        # (F)     (E)
        """
        node_attrs = get_mock_nncf_node_attrs()
        graph = nx.DiGraph()

        for key in 'ABCDEF':
            graph.add_node(key, **node_attrs)

        edge_list = [('A', 'B'), ('B', 'C'), ('B', 'D'), ('D', 'E'), ('C', 'F')]
        graph.add_edges_from(edge_list)
        mark_input_ports_lexicographically_based_on_input_node_key(graph)
        return graph


    # Describes one quantizer in a merge test case:
    # - init_node_to_trait_and_configs_dict: node key -> (QuantizationTrait, list of QuantizerConfig);
    # - starting_quantizer_ip_node: IP node key where the quantizer is added (a list of node keys in
    #   the expected-result structs);
    # - target_node_for_quantizer: node key the quantizer is expected to reach;
    # - is_merged: whether this quantizer takes part in the merge;
    # - prop_path: edge path for the merge, or None.
    StateQuantizerTestStruct = namedtuple('StateQuantizerTestStruct',
                                          ('init_node_to_trait_and_configs_dict',
                                           'starting_quantizer_ip_node',
                                           'target_node_for_quantizer',
                                           'is_merged',
                                           'prop_path'))

    # A merge test case: the quantizers set up initially and the quantizer state expected afterwards.
    SetQuantizersTestStruct = namedtuple('SetQuantizersTestStruct',
                                         ('start_set_quantizers',
                                          'expected_set_quantizers'))

    # Test cases for test_merge_quantizer_into_path, built on the graph from get_model_graph().
    # In every case:
    # - one quantizer starts at the pre-hook of 'E' (is_merged=False, no prop_path);
    # - one starts at the pre-hook of 'F' (is_merged=True, with a prop_path);
    # - the expected result is a single quantizer whose starting_quantizer_ip_node lists ['E', 'F'].
    # NOTE(review): the first case keys its expected trait dict by the literal 'PRE HOOK B', while the
    # other cases use 'B'. Confirm the literal matches InsertionPointGraph.get_pre_hook_node_key('B').
    MERGE_QUANTIZER_INTO_PATH_TEST_CASES = [
        # Case 1: 'E' quantizer targets the pre-hook of 'B'; 'F' quantizer targets the pre-hook of 'C'
        # with prop_path (post-hook B -> pre-hook C). Expected target: pre-hook of 'B'.
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None

                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('C'),
                    is_merged=True,
                    prop_path=[(InsertionPointGraph.get_post_hook_node_key('B'),
                                InsertionPointGraph.get_pre_hook_node_key('C'))]
                )

            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'PRE HOOK B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                                       [QuantizerConfig()])

                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        ),
        # Case 2: 'E' quantizer targets the pre-hook of 'B'; 'F' quantizer targets the post-hook of 'B'
        # with prop_path (B -> post-hook B). Expected target: pre-hook of 'B'.
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None

                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=True,
                    prop_path=[('B', InsertionPointGraph.get_post_hook_node_key('B'))]
                )

            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])

                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        ),
        # Case 3: 'E' quantizer targets the post-hook of 'B'; 'F' quantizer targets the pre-hook of 'C'
        # with prop_path (post-hook B -> pre-hook C). Expected target: post-hook of 'B'.
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None

                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('C'),
                    is_merged=True,
                    prop_path=[(InsertionPointGraph.get_post_hook_node_key('B'),
                                InsertionPointGraph.get_pre_hook_node_key('C'))]
                )

            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])

                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        )
    ]

    @staticmethod
    @pytest.fixture(params=MERGE_QUANTIZER_INTO_PATH_TEST_CASES)
    def merge_quantizer_into_path_test_struct(request):
        """Parametrized fixture: hands each MERGE_QUANTIZER_INTO_PATH_TEST_CASES entry to the test in turn."""
        current_case = request.param
        return current_case

    @pytest.fixture
    def model_graph_qpsg(self):
        """Fresh QPSG built on top of the insertion-point graph of ``self.get_model_graph()``."""
        return QPSG(InsertionPointGraph(self.get_model_graph()))

    def test_merge_quantizer_into_path(self, model_graph_qpsg, merge_quantizer_into_path_test_struct):
        """Propagate every starting quantizer, then merge the flagged ones and check the result.

        For each entry of ``start_set_quantizers`` the graph nodes are marked with traits,
        a propagating quantizer is placed at the pre-hook of every INPUTS_QUANTIZABLE node,
        and the one starting at ``starting_quantizer_ip_node`` is moved to
        ``target_node_for_quantizer``. Entries with ``is_merged=True`` are merged into their
        ``prop_path`` only after *all* entries have been propagated; the final graph is then
        compared against ``expected_set_quantizers``.
        """
        quant_prop_graph = model_graph_qpsg
        # Collected across all start structs - merging must wait until every quantizer
        # has been propagated, so this list must not be reset inside the loop below.
        merged_prop_quant = []

        for quantizers_test_struct in merge_quantizer_into_path_test_struct.start_set_quantizers:
            init_node_to_trait_and_configs_dict = quantizers_test_struct.init_node_to_trait_and_configs_dict
            starting_quantizer_ip_node = quantizers_test_struct.starting_quantizer_ip_node
            target_node = quantizers_test_struct.target_node_for_quantizer
            is_merged = quantizers_test_struct.is_merged
            prop_path = quantizers_test_struct.prop_path

            # Every node defaults to quantization-agnostic; only the listed nodes get
            # an explicit trait.
            node_key_vs_trait_dict = dict.fromkeys(
                quant_prop_graph.nodes,
                QuantizationTrait.QUANTIZATION_AGNOSTIC)  # type: Dict[str, QuantizationTrait]
            for node_key, trait_and_configs_tuple in init_node_to_trait_and_configs_dict.items():
                node_key_vs_trait_dict[node_key] = trait_and_configs_tuple[0]
            quant_prop_graph = self.mark_nodes_with_traits(quant_prop_graph, node_key_vs_trait_dict)

            primary_prop_quant = None
            for node_key, trait_and_configs_tuple in init_node_to_trait_and_configs_dict.items():
                trait = trait_and_configs_tuple[0]
                qconfigs = trait_and_configs_tuple[1]
                if trait != QuantizationTrait.INPUTS_QUANTIZABLE:
                    continue
                ip_node_key = InsertionPointGraph.get_pre_hook_node_key(node_key)
                prop_quant = quant_prop_graph.add_propagating_quantizer(qconfigs, ip_node_key)
                if ip_node_key == starting_quantizer_ip_node:
                    primary_prop_quant = prop_quant
            # Fail with a clear message instead of deep inside propagate_quantizer_via_path.
            assert primary_prop_quant is not None, \
                "No propagating quantizer was created at {}".format(starting_quantizer_ip_node)

            path = get_edge_paths_for_propagation(quant_prop_graph,
                                                  target_node,
                                                  starting_quantizer_ip_node)
            primary_prop_quant = quant_prop_graph.propagate_quantizer_via_path(primary_prop_quant,
                                                                               path[0])
            if is_merged:
                merged_prop_quant.append((primary_prop_quant, prop_path))

            quant_prop_graph.run_consistency_check()

        for prop_quant, prop_path in merged_prop_quant:
            quant_prop_graph.merge_quantizer_into_path(prop_quant, prop_path)
            quant_prop_graph.run_consistency_check()

        expected_quantizers_test_struct = merge_quantizer_into_path_test_struct.expected_set_quantizers
        self.check_final_state_qpsg(quant_prop_graph, expected_quantizers_test_struct)

    @staticmethod
    def check_final_state_qpsg(final_quant_prop_graph, expected_quantizers_test_struct):
        """Assert that each expected quantizer sits at its target and covers the expected edges.

        For every expected struct, the union of the first edge path from the target node to
        each listed start node must be exactly the quantizer's ``affected_edges``. Each edge on
        it must list the quantizer as affecting it, and so must every insertion-point node
        those edges start from.
        """
        for expected in expected_quantizers_test_struct:
            target_key = expected.target_node_for_quantizer

            expected_edges = set()
            for start_key in expected.starting_quantizer_ip_node:
                candidate_paths = get_edge_paths_for_propagation(final_quant_prop_graph,
                                                                 target_key,
                                                                 start_key)
                expected_edges.update(candidate_paths[0])

            target_attrs = final_quant_prop_graph.nodes[target_key]
            quantizer = target_attrs[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR]
            assert quantizer is not None

            for edge in expected_edges:
                edge_start_key, _ = edge
                edge_attrs = final_quant_prop_graph.edges[edge]
                assert quantizer in edge_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                start_attrs = final_quant_prop_graph.nodes[edge_start_key]
                is_insertion_point = (start_attrs[QPSG.NODE_TYPE_NODE_ATTR] ==
                                      QuantizerPropagationStateGraphNodeType.INSERTION_POINT)
                if is_insertion_point:
                    assert quantizer in start_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

            assert quantizer.affected_edges == expected_edges

if __name__ == '__main__':
    # For each (name, input size, number of GPU runs) case, time the reference PyTorch
    # quantization against the custom SymmetricQuantizer on CUDA.
    for case_name, case_input_size, case_gpu_runs in TEST_PARAMS_STRUCT:
        print("CUDA " + case_name)
        print("------------------------------------------------")

        print("Pytorch Symmetric (cuda 0) impl:")
        print("input size: {0}".format(case_input_size))
        per_tensor_reference = ReferenceQuantize(NBITS).cuda()
        run_profile(per_tensor_reference, case_input_size, 'cuda', case_gpu_runs)

        print()
        print("Custom Symmetric (cuda 0 ) impl:")
        print("input size: {0}".format(case_input_size))
        per_tensor_custom = SymmetricQuantizer(QuantizerConfig(bits=NBITS)).cuda()
        run_profile(per_tensor_custom, case_input_size, 'cuda', case_gpu_runs)

        print()
        print("Pytorch Symmetric Per Weight Channel (cuda 0) impl:")
        print("input size: {0}".format(case_input_size))
        per_channel_reference = ReferenceQuantize(NBITS,
                                                  input_shape=case_input_size,
                                                  per_channel=True,
                                                  is_weights=True).cuda()
        run_profile(per_channel_reference, case_input_size, 'cuda', case_gpu_runs)

        print()
        # NOTE(review): this section prints its header but no run_profile call follows,
        # so the custom per-channel measurement looks missing/truncated - confirm.
        print("Custom Symmetric Per Weight Channel  (cuda 0 ) impl")
        print("input size: {0}".format(case_input_size))
def create_quantize_module(_, is_weights=False, input_shape=None):
    """Build a SymmetricQuantizer whose config is driven by ``is_weights``.

    ``is_weights`` is used both as the config's ``is_weights`` flag and as its
    ``signedness_to_force``. The first positional argument and ``input_shape`` are
    accepted but ignored (presumably to match a shared factory signature - confirm).
    """
    config = QuantizerConfig(signedness_to_force=is_weights, is_weights=is_weights)
    return SymmetricQuantizer(config)
Exemplo n.º 30
0
class QuantizerPropagationSolver:
    """Analyzes a fresh QuantizerPropagationStateGraph object according to HW
       configuration supplied in the initializer and produces the list of insertion
       commands that correspond to the final state of the quantizer propagation graph
       when as many of the model's control flow graph edges as possible are quantized
       according to HW capabilities."""

    # Allowed input quantizer configs used when no HW config is supplied:
    # _get_operator_qconfigs_map hands these out to every metatype listed under
    # QuantizationTrait.INPUTS_QUANTIZABLE in DEFAULT_QUANT_TRAIT_TO_OP_DICT.
    # A single 8-bit symmetric, non-per-channel config; signedness_to_force=None
    # presumably leaves the signedness to be decided later - TODO confirm.
    DEFAULT_QUANTIZATION_TYPES = [
        QuantizerConfig(bits=8,
                        mode=QuantizationMode.SYMMETRIC,
                        signedness_to_force=None,
                        per_channel=False)
    ]

    def __init__(
        self,
        ignored_scopes=None,
        hw_config=None,
        debug_interface: 'QuantizationDebugInterface' = None,
        propagation_strategy: PropagationStrategy = PropagationStrategy.AGGRESSIVE):
        """
        ignored_scopes -- forwarded to QuantizerPropagationStateGraph when a graph is built.
        hw_config -- if None, operator traits/configs come from OPERATOR_METATYPES and
            DEFAULT_QUANT_TRAIT_TO_OP_DICT; otherwise from
            hw_config.get_metatype_vs_quantizer_configs_map().
        debug_interface -- if not None, asked to visualize the graph on every propagation
            iteration and once at the end.
        propagation_strategy -- CONSERVATIVE additionally refuses to pass a branching node
            when that would narrow the configs available on any branch.
        """
        self._hw_config = hw_config
        self._debug_interface = debug_interface
        self._propagation_strategy = propagation_strategy  # TODO: determine from config
        # Both maps read self._hw_config, so it must be assigned before them.
        self._operator_quantization_trait_map = self.get_operator_quantization_traits_map()
        self._operator_allowed_qconfigs_map = self._get_operator_qconfigs_map()
        self._active_propagating_quantizers_queue = deque()
        self._finished_propagating_quantizers = []  # type: List[PropagatingQuantizer]
        # InsertionPoint -> potential quantizer configs; filled by run_on_ip_graph.
        # Initialised here so the attribute exists before the first run.
        self._potential_quantizers = {}
        self._ignored_scopes = ignored_scopes

    def run_on_ip_graph(
        self, ip_graph: InsertionPointGraph
    ) -> Dict[InsertionInfo, Optional[List[QuantizerConfig]]]:
        """ The main function to be used on an InsertionPointGraph to produce
            the list of insertion commands and configs corresponding to the final quantized
            graph state.

            Builds a QuantizerPropagationStateGraph from ip_graph, places the initial
            quantizers, then repeatedly takes a quantizer off the active queue and
            propagates it one step until the queue is empty. For every finished quantizer,
            an InsertionInfo built from its final insertion point is mapped to the
            quantizer's potential configs; the same configs are also recorded per
            InsertionPoint in self._potential_quantizers.

            The active queue and the finished list are reset on entry, so one solver
            can be run on several graphs in turn.
        """
        # Quantizers left over from a previous call refer to nodes of another graph;
        # keeping them would mix them into (or break) this run's result.
        self._active_propagating_quantizers_queue = deque()
        self._finished_propagating_quantizers = []  # type: List[PropagatingQuantizer]

        quant_prop_graph = QuantizerPropagationStateGraph(
            ip_graph, self._ignored_scopes)
        quant_prop_graph = self.set_allowed_quantization_types_for_operator_nodes(
            quant_prop_graph)
        quant_prop_graph = self.setup_initial_quantizers(quant_prop_graph)

        iteration_counter = 0
        while self._active_propagating_quantizers_queue:
            # Quantizers are added with appendleft() and taken with pop(): FIFO order.
            prop_quantizer = self._active_propagating_quantizers_queue.pop()
            if self._debug_interface is not None:
                self._debug_interface.visualize_quantizer_propagation(
                    self, quant_prop_graph, str(iteration_counter))
            quant_prop_graph = self.propagation_step(prop_quantizer,
                                                     quant_prop_graph)
            iteration_counter += 1

        if self._debug_interface is not None:
            self._debug_interface.visualize_quantizer_propagation(
                self, quant_prop_graph, "final")

        retval = {}
        self._potential_quantizers = {}
        for finished_prop_quantizer in self._finished_propagating_quantizers:
            final_node = quant_prop_graph.nodes[finished_prop_quantizer.current_location_node_key]
            insertion_point = final_node[
                QuantizerPropagationStateGraph.INSERTION_POINT_DATA_NODE_ATTR]  # type: InsertionPoint
            ia_op_exec_context = insertion_point.ia_op_exec_context
            insertion_info = InsertionInfo(
                OperationExecutionContext(
                    operator_name=ia_op_exec_context.operator_name,
                    scope_in_model=ia_op_exec_context.scope_in_model,
                    call_order=ia_op_exec_context.call_order,
                    tensor_metas=[None])
            )  # TODO: fix this, rethink InsertionInfo here and elsewhere
            self._potential_quantizers[insertion_point] = finished_prop_quantizer.potential_quant_configs
            retval[insertion_info] = finished_prop_quantizer.potential_quant_configs
        return retval

    def propagation_step(
        self, curr_prop_quantizer: PropagatingQuantizer,
        quant_prop_graph: QuantizerPropagationStateGraph
    ) -> QuantizerPropagationStateGraph:
        """Advance curr_prop_quantizer by a single step; returns quant_prop_graph.

        The quantizer must currently sit on an insertion-point node. If no insertion point
        dominates it, it is backtracked to its last accepting location and recorded as
        finished (when the backtrack yields a quantizer). Otherwise the quantizer is cloned
        so that there is one copy per path to an immediately dominating insertion point.
        Each copy, according to check_transition_via_path, is then one of the following:
        - backtracked and finished;
        - moved along its path and re-queued;
        - merged into the quantizer already occupying that path.
        """
        # TODO: full-fledged discrete finite automata approach? Switch to traversing a graph
        # consisting of insertion points only, with reversed edges holding associated operator data?
        def finish_at_accepting_location(quantizer):
            accepted = quant_prop_graph.backtrack_propagation_until_accepting_location(quantizer)
            if accepted is not None:
                self._finished_propagating_quantizers.append(accepted)

        location_key = curr_prop_quantizer.current_location_node_key
        location_node = quant_prop_graph.nodes[location_key]
        assert location_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR] == \
            QuantizerPropagationStateGraphNodeType.INSERTION_POINT

        # Assumption: paths are at most 2 edges - either one edge to neighbouring insertion point
        # or one edge to operation and next edge to its own neighbouring insertion point.
        paths = quant_prop_graph.get_paths_to_immediately_dominating_insertion_points(location_key)
        if not paths:
            finish_at_accepting_location(curr_prop_quantizer)
            return quant_prop_graph

        # The original quantizer handles the first path, clones handle the rest.
        per_path_quantizers = [curr_prop_quantizer]
        per_path_quantizers.extend(
            quant_prop_graph.clone_propagating_quantizer(curr_prop_quantizer)
            for _ in range(len(paths) - 1))

        survivors = []
        for path, quantizer in zip(paths, per_path_quantizers):
            status = self.check_transition_via_path(quantizer, path, quant_prop_graph)
            if status == TransitionStatus.SHOULD_TRANSITION:
                survivors.append(quant_prop_graph.propagate_quantizer_via_path(quantizer, path))
            elif status == TransitionStatus.SHOULD_NOT_TRANSITION:
                finish_at_accepting_location(quantizer)
            elif status == TransitionStatus.SHOULD_MERGE:
                # The surviving quantizer will have its "affected edges" set extended
                # by the corresponding set of the merged quantizer. The assumption
                # here is that the surviving quantizer should never have to cross
                # such a "merge point" while backtracking to an accepting location.
                quant_prop_graph.merge_quantizer_into_path(quantizer, path)

        # extendleft() is a series of appendleft() calls, in the same order.
        self._active_propagating_quantizers_queue.extendleft(survivors)
        return quant_prop_graph

    def get_allowed_quantizer_configs_for_operator(
            self, quant_det_id: OperatorMetatype) -> List[QuantizerConfig]:
        """Look up the input quantizer configs allowed for operator metatype quant_det_id.

        Raises KeyError if the metatype is absent from the map built in __init__.
        """
        allowed_configs = self._operator_allowed_qconfigs_map[quant_det_id]
        return allowed_configs

    def set_allowed_quantization_types_for_operator_nodes(
            self, quant_prop_graph: QuantizerPropagationStateGraph):
        """Annotate every operator node with its quantization trait (and allowed configs).

        Operators with no metatype trigger a warning and are treated as
        quantization-agnostic. INPUTS_QUANTIZABLE operators also receive their allowed
        input quantizer configs. The graph is modified in place and returned.
        """
        qpsg = QuantizerPropagationStateGraph
        for node_key, node in quant_prop_graph.nodes.items():
            if node[qpsg.NODE_TYPE_NODE_ATTR] != QuantizerPropagationStateGraphNodeType.OPERATOR:
                continue

            metatype = node[qpsg.OPERATOR_METATYPE_NODE_ATTR]
            if metatype is None:
                warnings.warn("Unknown metatype for operator node: {}".format(node_key))
                trait = QuantizationTrait.QUANTIZATION_AGNOSTIC
            else:
                trait = self._operator_quantization_trait_map[metatype]

            node[qpsg.QUANTIZATION_TRAIT_NODE_ATTR] = trait
            if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
                node[qpsg.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR] = \
                    self.get_allowed_quantizer_configs_for_operator(metatype)
        return quant_prop_graph

    def get_operator_quantization_traits_map(
            self) -> Dict[OperatorMetatype, QuantizationTrait]:
        """Map each operator metatype to its QuantizationTrait.

        With a HW config, the trait follows the metatype's configs list:
        - None -> QUANTIZATION_AGNOSTIC;
        - non-empty -> INPUTS_QUANTIZABLE;
        - empty -> NON_QUANTIZABLE.

        Without one, every registered metatype starts as QUANTIZATION_AGNOSTIC and is
        then overridden by DEFAULT_QUANT_TRAIT_TO_OP_DICT.
        """
        # TODO: ensure that there are no name collisions between ops in different torch subpackages with the same name
        if self._hw_config is not None:
            traits = {}
            for op_meta, qconf_list in self._hw_config.get_metatype_vs_quantizer_configs_map().items():
                if qconf_list is None:
                    traits[op_meta] = QuantizationTrait.QUANTIZATION_AGNOSTIC
                elif qconf_list:
                    traits[op_meta] = QuantizationTrait.INPUTS_QUANTIZABLE
                else:
                    traits[op_meta] = QuantizationTrait.NON_QUANTIZABLE
            return traits

        traits = dict.fromkeys(OPERATOR_METATYPES.registry_dict.values(),
                               QuantizationTrait.QUANTIZATION_AGNOSTIC)  # Default value
        for trait, meta_list in DEFAULT_QUANT_TRAIT_TO_OP_DICT.items():
            traits.update(dict.fromkeys(meta_list, trait))
        return traits

    def _get_operator_qconfigs_map(
            self) -> Dict[OperatorMetatype, Optional[List[QuantizerConfig]]]:
        """Map each operator metatype to the quantizer configs allowed for its inputs.

        With a HW config, the map is hw_config.get_metatype_vs_quantizer_configs_map().
        Without one, the values are:
        - a copy of DEFAULT_QUANTIZATION_TYPES for metatypes listed as
          INPUTS_QUANTIZABLE in DEFAULT_QUANT_TRAIT_TO_OP_DICT;
        - ``[]`` for the other listed metatypes;
        - ``None`` for every remaining registered metatype.

        ``None`` is the value the HW-config convention uses for a quantization-agnostic
        operator (see get_operator_quantization_traits_map).
        """
        # TODO: ensure that there are no name collisions between ops in different torch subpackages with the same name
        if self._hw_config is not None:
            return self._hw_config.get_metatype_vs_quantizer_configs_map()

        # Default value: "no configs" (agnostic). Previously this was a QuantizationTrait
        # enum copied from the traits map, which is not a config list.
        retval = dict.fromkeys(OPERATOR_METATYPES.registry_dict.values())
        for trait, meta_list in DEFAULT_QUANT_TRAIT_TO_OP_DICT.items():
            for op_meta in meta_list:  # type: OperatorMetatype
                if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
                    # Shallow copy: the same QuantizerConfig objects, but no consumer can
                    # mutate the class-level default list shared by all ops and solvers.
                    retval[op_meta] = list(self.DEFAULT_QUANTIZATION_TYPES)
                else:
                    retval[op_meta] = []
        return retval

    def debug_visualize(self, quant_prop_graph: QuantizerPropagationStateGraph,
                        dump_path: str):
        """Write quant_prop_graph as a DOT file to dump_path.

        The graph label lists the ids of the active and finished propagating quantizers.
        """
        out_graph = quant_prop_graph.get_visualized_graph()
        active_ids = [str(pq.id) for pq in self._active_propagating_quantizers_queue]
        finished_ids = [str(pq.id) for pq in self._finished_propagating_quantizers]
        label = "Propagating quantizers: {}\nFinished quantizers: {}".format(
            ", ".join(active_ids), ", ".join(finished_ids))
        out_graph.graph['graph'] = {"label": label,
                                    "labelloc": "t"}
        nx.drawing.nx_pydot.write_dot(out_graph, dump_path)

    def setup_initial_quantizers(
        self, quant_prop_graph: QuantizerPropagationStateGraph
    ) -> QuantizerPropagationStateGraph:
        """Place one propagating quantizer in front of every quantizable operator.

        Ignored operators, operators without predecessors, and operators whose trait is
        NON_QUANTIZABLE or QUANTIZATION_AGNOSTIC are skipped. Every other operator gets
        a quantizer with its allowed configs at its first predecessor, which must be an
        insertion point. Each new quantizer is queued for propagation.
        """
        qpsg = QuantizerPropagationStateGraph
        skipped_traits = (QuantizationTrait.NON_QUANTIZABLE,
                          QuantizationTrait.QUANTIZATION_AGNOSTIC)
        for node_key, node in quant_prop_graph.nodes.items():
            if node[qpsg.NODE_TYPE_NODE_ATTR] != QuantizerPropagationStateGraphNodeType.OPERATOR:
                continue
            if node_key in quant_prop_graph.ignored_node_keys:
                continue

            predecessors = list(quant_prop_graph.predecessors(node_key))
            if not predecessors:
                continue  # TODO: remove this once module insertion points are included in the IP graph

            # The operator must be fed directly by an insertion point.
            ip_key = predecessors[0]
            ip_node_type = quant_prop_graph.nodes[ip_key][qpsg.NODE_TYPE_NODE_ATTR]
            assert ip_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT, \
                "Invalid insertion point graph supplied for quantizer propagation!"

            if node[qpsg.QUANTIZATION_TRAIT_NODE_ATTR] in skipped_traits:
                continue

            allowed_qconfs = self.get_allowed_quantizer_configs_for_operator(
                node[qpsg.OPERATOR_METATYPE_NODE_ATTR])
            new_quantizer = quant_prop_graph.add_propagating_quantizer(allowed_qconfs, ip_key)
            self._active_propagating_quantizers_queue.appendleft(new_quantizer)

        return quant_prop_graph

    def check_branching_transition(
            self, quant_prop_graph: QuantizerPropagationStateGraph,
            prop_quantizer: PropagatingQuantizer,
            branching_node_key: str) -> Optional[TransitionStatus]:
        """Decide whether prop_quantizer may rise through a downward-branching node.

        Passing such a node imposes the quantizer's precision on the neighbouring branches,
        which may not be desirable when some branches should stay at higher precision.
        Returns SHOULD_NOT_TRANSITION in any of these cases:
        - some quantizable op directly below the node has no affecting quantizer (it stays FP32);
        - no config of prop_quantizer is compatible with the configs on every branch;
        - under the CONSERVATIVE strategy, merging would narrow any branch's configs.

        Otherwise returns None, leaving the decision to the caller.
        """
        slave_possible_qconfigs_dict = {}
        for op_node_key in quant_prop_graph.get_quantizable_op_nodes_immediately_dominated_by_node(
                branching_node_key):
            affecting = quant_prop_graph.nodes[op_node_key][
                QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            if not affecting:
                # The branch op is forced to be FP32 - should not proceed through the branch node.
                return TransitionStatus.SHOULD_NOT_TRANSITION
            slave_possible_qconfigs_dict[op_node_key] = affecting[0].potential_quant_configs

        master_merged_qconfigs, slave_merged_qconfigs_dict = self.get_merged_qconfigs(
            prop_quantizer.potential_quant_configs, slave_possible_qconfigs_dict)
        if not master_merged_qconfigs:
            # This quantizer's precision does not encompass the precisions of quantizers
            # propagating through downward branches.
            return TransitionStatus.SHOULD_NOT_TRANSITION

        if self._propagation_strategy == PropagationStrategy.CONSERVATIVE:
            narrows_some_branch = any(
                len(slave_possible_qconfigs_dict[key]) != len(merged_list)
                for key, merged_list in slave_merged_qconfigs_dict.items())
            if narrows_some_branch:
                return TransitionStatus.SHOULD_NOT_TRANSITION

        return None

    def check_transition_via_path(
            self, prop_quantizer: PropagatingQuantizer, path: List,
            quant_prop_graph: QuantizerPropagationStateGraph
    ) -> TransitionStatus:
        """Decide what prop_quantizer should do about one of its candidate paths.

        Each (from, to) edge of path is examined in order. The first of these rules to
        fire decides the result:
        - a branching-node veto from check_branching_transition;
        - an operator that is NON_QUANTIZABLE or INPUTS_QUANTIZABLE (stop);
        - an edge already affected by a quantizer (merge if the configs are equal,
          otherwise stop);
        - an affected insertion point whose first quantizer has equal configs (merge).

        If no rule fires, the quantizer should transition.
        """
        qpsg = QuantizerPropagationStateGraph
        for from_node_key, to_node_key in path:
            from_node = quant_prop_graph.nodes[from_node_key]

            if len(list(quant_prop_graph.successors(from_node_key))) > 1:
                # Rising through a downward-branching node may spoil the precision of ops on
                # neighbouring branches (e.g. a 4-bit quantizer passes first, an 8-bit one
                # arrives later), so the quantizer's precision must encompass that of the
                # first non-quantization-agnostic op on every branch.
                branching_status = self.check_branching_transition(
                    quant_prop_graph, prop_quantizer, from_node_key)
                if branching_status is not None:
                    return branching_status

            from_node_type = from_node[qpsg.NODE_TYPE_NODE_ATTR]
            is_operator = from_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR
            if is_operator and from_node[qpsg.QUANTIZATION_TRAIT_NODE_ATTR] in (
                    QuantizationTrait.NON_QUANTIZABLE,
                    QuantizationTrait.INPUTS_QUANTIZABLE):
                return TransitionStatus.SHOULD_NOT_TRANSITION

            edge_quantizers = quant_prop_graph.edges[from_node_key, to_node_key][
                qpsg.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            if edge_quantizers:
                # Multiple affecting quantizers share one config by construction.
                same_configs = (prop_quantizer.potential_quant_configs ==
                                edge_quantizers[0].potential_quant_configs)
                return TransitionStatus.SHOULD_MERGE if same_configs \
                    else TransitionStatus.SHOULD_NOT_TRANSITION

            if from_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                ip_quantizers = from_node[qpsg.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                # Affecting quantizers share one config by construction; check the first.
                if ip_quantizers and (prop_quantizer.potential_quant_configs ==
                                      ip_quantizers[0].potential_quant_configs):
                    return TransitionStatus.SHOULD_MERGE
        return TransitionStatus.SHOULD_TRANSITION

    def get_merged_qconfigs(
        self, master_potential_qconfigs_list: List[QuantizerConfig],
        slave_potential_qconfigs_dict: Dict[str, List[QuantizerConfig]]
    ) -> Tuple[List[QuantizerConfig], Dict[str, List[QuantizerConfig]]]:
        """Returns potential qconfigs lists for 'master' and 'slave' quantizers
        that are compatible with each other. Compatibility is decided in terms of
        master quantizer having configs which all have higher precision than all the
        slave potential quantizer configs.

        Greedy procedure over the master configs, in list order:
          * for the current master config m, tentatively drop from every slave list
            each config s with ``m < s`` (QuantizerConfig ordering - presumably
            "lower precision than"; TODO confirm);
          * if that would leave any slave list empty, m is removed from the master
            result and the slave lists stay as they were;
          * otherwise the reduced slave lists become the baseline for the next m.

        master_potential_qconfigs_list -- candidate configs of the 'master' quantizer.
        slave_potential_qconfigs_dict -- node key -> candidate configs for that node.
        Returns (surviving master configs, node key -> surviving slave configs), or
        ([], {}) if no master config survives. Both are built from deep copies, so the
        arguments are not modified.
        """
        final_master_merged_qconfigs_list = deepcopy(
            master_potential_qconfigs_list)
        curr_slave_merged_qconfigs_dict = deepcopy(
            slave_potential_qconfigs_dict)
        # TODO: implement variant solutions, i.e. for each set of resultant merged
        # master potential qconfig lists we have, in general, different merged slave potential
        # config lists. Currently greedy approach is used.
        for m_qconfig in master_potential_qconfigs_list:
            should_persist_slave_merged_qconfigs_dict = True
            # Work on a copy so that rejecting m_qconfig leaves the slave lists untouched.
            candidate_slave_merged_qconfigs_dict = deepcopy(
                curr_slave_merged_qconfigs_dict)
            for node_key, s_qconfig_list in curr_slave_merged_qconfigs_dict.items(
            ):
                for s_qconfig in s_qconfig_list:
                    # Drop slave configs that compare greater than the master config; the
                    # membership test keeps duplicates from being removed more than once.
                    if m_qconfig < s_qconfig and s_qconfig in candidate_slave_merged_qconfigs_dict[
                            node_key]:
                        candidate_slave_merged_qconfigs_dict[node_key].remove(
                            s_qconfig)
            for _, s_qconfig_list in candidate_slave_merged_qconfigs_dict.items(
            ):
                if not s_qconfig_list:
                    # No options left for slave configs on one of the branches to accomodate the master
                    # config - this master config cannot be used to be merged into.
                    # NOTE(review): remove() matches by equality against deep copies, so this
                    # relies on QuantizerConfig defining value equality (else ValueError).
                    final_master_merged_qconfigs_list.remove(m_qconfig)
                    should_persist_slave_merged_qconfigs_dict = False
                    break
            if should_persist_slave_merged_qconfigs_dict:
                curr_slave_merged_qconfigs_dict = candidate_slave_merged_qconfigs_dict
        if not final_master_merged_qconfigs_list:
            return [], {}
        return final_master_merged_qconfigs_list, curr_slave_merged_qconfigs_dict

    def get_finished_propagating_quantizers(self):
        """Return the list of quantizers that reached their final location (not a copy)."""
        finished = self._finished_propagating_quantizers
        return finished

    def get_active_propagating_quantizers_queue(self):
        """Return the deque of quantizers still awaiting propagation (not a copy)."""
        pending = self._active_propagating_quantizers_queue
        return pending