def test_add_propagating_quantizer(self, mock_qp_graph):
    """Check graph bookkeeping after registering one propagating quantizer.

    A quantizer with two candidate configs is added at the pre-hook of
    operator "F". The test then verifies the quantizer's own state, the
    node and edge attributes of the whole graph, and that a second add at
    the post-hook of "F" raises RuntimeError.
    """
    op_key = "F"
    pre_hook_key = InsertionPointGraph.get_pre_hook_node_key(op_key)
    ref_qconf_list = [QuantizerConfig(), QuantizerConfig(bits=6)]
    expected_edge = (pre_hook_key, op_key)

    prop_quant = mock_qp_graph.add_propagating_quantizer(ref_qconf_list, pre_hook_key)

    # State of the freshly created quantizer itself.
    assert prop_quant.potential_quant_configs == ref_qconf_list
    assert prop_quant.current_location_node_key == pre_hook_key
    assert prop_quant.affected_ip_nodes == {pre_hook_key}
    assert prop_quant.last_accepting_location_node_key is None
    assert prop_quant.affected_edges == {expected_edge}
    assert not prop_quant.propagation_path

    # Node attributes: only the operator and its pre-hook may reference the quantizer.
    for key, attrs in mock_qp_graph.nodes.items():
        if key == op_key:
            assert attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            continue
        if key == pre_hook_key:
            assert attrs[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
            continue
        assert not attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        if attrs[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
            assert attrs[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

    assert mock_qp_graph.nodes[pre_hook_key][QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

    # Edge attributes: only the pre-hook -> operator edge is affected.
    for src, dst, attrs in mock_qp_graph.edges.data():
        affecting = attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        if (src, dst) == expected_edge:
            assert affecting == [prop_quant]
        else:
            assert not affecting

    with pytest.raises(RuntimeError):
        _ = mock_qp_graph.add_propagating_quantizer(ref_qconf_list,
                                                    InsertionPointGraph.get_post_hook_node_key(op_key))
def test_quantization_configs__custom():
    """Check that custom weight/activation settings reach the created quantizers.

    The basic quantization config is overridden with 4-bit asymmetric
    settings, the compression algorithm is built, and every weight and
    activation quantizer config is compared with a reference config.
    """
    model = BasicConvTestModel()
    config = get_basic_quantization_config()
    custom_settings = {
        "weights": {
            "mode": "asymmetric",
            "per_channel": True,
            "signed": False,
            "bits": 4
        },
        "activations": {
            "mode": "asymmetric",
            "bits": 4,
            "signed": True,
        },
    }
    config['compression'].update(custom_settings)

    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)
    compression_algo = create_compression_algorithm(model, config)
    weight_quantizers, activation_quantizers = split_quantizers(compression_algo.model)

    expected_weight_qconfig = QuantizerConfig(4, QuantizationMode.ASYMMETRIC, None, True, None, True)
    for quantizer in weight_quantizers:
        compare_qconfigs(expected_weight_qconfig, quantizer.config)

    expected_activation_qconfig = QuantizerConfig(4, QuantizationMode.ASYMMETRIC, True, False, None, False)
    for quantizer in activation_quantizers:
        compare_qconfigs(expected_activation_qconfig, quantizer.config)
def test_remove_propagating_quantizer(self, mock_qp_graph, start_target_nodes_for_two_quantizers):
    """Removing one propagating quantizer must leave another one untouched.

    Two quantizers are added and propagated along separate paths. One is
    removed; the test verifies that it no longer appears on any node or
    edge it affected, and that the kept quantizer's state and graph
    registrations are unchanged.

    :param mock_qp_graph: Quantizer propagation state graph fixture.
    :param start_target_nodes_for_two_quantizers: Indexable of four node keys:
        (start, target) for the quantizer to remove, then (start, target)
        for the quantizer to keep.
    """
    start_ip_node_key_remove = start_target_nodes_for_two_quantizers[0]
    target_ip_node_key_remove = start_target_nodes_for_two_quantizers[1]
    start_ip_node_key_keep = start_target_nodes_for_two_quantizers[2]
    target_ip_node_key_keep = start_target_nodes_for_two_quantizers[3]

    # From "target" to "start" since propagation direction is inverse to edge direction
    # Only take one path out of possible paths for this test
    rev_path_remove = get_edge_paths_for_propagation(mock_qp_graph,
                                                     target_ip_node_key_remove,
                                                     start_ip_node_key_remove)[0]
    rev_path_keep = get_edge_paths_for_propagation(mock_qp_graph,
                                                   target_ip_node_key_keep,
                                                   start_ip_node_key_keep)[0]

    prop_quant_to_remove = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                   start_ip_node_key_remove)
    prop_quant_to_remove = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_remove,
                                                                      rev_path_remove)
    prop_quant_to_keep = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                 start_ip_node_key_keep)
    prop_quant_to_keep = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_keep,
                                                                    rev_path_keep)

    # Snapshot what the quantizer-to-remove affected BEFORE removal.
    affected_ip_nodes = deepcopy(prop_quant_to_remove.affected_ip_nodes)
    affected_op_nodes = deepcopy(prop_quant_to_remove.affected_operator_nodes)
    # The edges must come from the removed quantizer as well; taking them from
    # the kept quantizer would leave the removed quantizer's own edges unchecked.
    affected_edges = deepcopy(prop_quant_to_remove.affected_edges)
    last_location = prop_quant_to_remove.current_location_node_key
    ref_quant_to_keep_state_dict = deepcopy(prop_quant_to_keep.__dict__)

    mock_qp_graph.remove_propagating_quantizer(prop_quant_to_remove)

    last_node = mock_qp_graph.nodes[last_location]
    assert last_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

    for ip_node_key in affected_ip_nodes:
        node = mock_qp_graph.nodes[ip_node_key]
        assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    for op_node_key in affected_op_nodes:
        node = mock_qp_graph.nodes[op_node_key]
        assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    for from_node_key, to_node_key in affected_edges:
        edge = mock_qp_graph.edges[from_node_key, to_node_key]
        assert prop_quant_to_remove not in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    # The kept quantizer must be completely unaffected by the removal.
    assert prop_quant_to_keep.__dict__ == ref_quant_to_keep_state_dict

    for ip_node_key in prop_quant_to_keep.affected_ip_nodes:
        node = mock_qp_graph.nodes[ip_node_key]
        assert prop_quant_to_keep in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    for from_node_key, to_node_key in prop_quant_to_keep.affected_edges:
        edge = mock_qp_graph.edges[from_node_key, to_node_key]
        assert prop_quant_to_keep in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        # Retains the check the original test performed on the kept quantizer's edges.
        assert prop_quant_to_remove not in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
    """Merge quantizers from the 'C' and 'D' pre-hooks at the 'B' post-hook, then add one at 'E'.

    :param qpsg: Quantizer propagation state graph to populate; modified in place.
    :return: The same `qpsg` object.
    """
    branch_quantizers = [
        qpsg.add_propagating_quantizer([QuantizerConfig()],
                                       InsertionPointGraph.get_pre_hook_node_key(op_key))
        for op_key in ('C', 'D')
    ]
    merge_location = InsertionPointGraph.get_post_hook_node_key('B')
    qpsg.merge_quantizers_for_branching_node(branch_quantizers,
                                             [QuantizerConfig()],
                                             [None, None],
                                             merge_location)

    e_pre_hook = InsertionPointGraph.get_pre_hook_node_key('E')
    qpsg.add_propagating_quantizer([QuantizerConfig()], e_pre_hook)
    return qpsg
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
    """Propagate a quantizer from the 'F' pre-hook to the 'C' post-hook, then add a 6-bit one at 'F'.

    :param qpsg: Quantizer propagation state graph to populate; modified in place.
    :return: The same `qpsg` object.
    """
    f_pre_hook = InsertionPointGraph.get_pre_hook_node_key('F')
    first_quantizer = qpsg.add_propagating_quantizer([QuantizerConfig()], f_pre_hook)

    # Single-edge path, given as (from_node, to_node).
    c_post_hook = InsertionPointGraph.get_post_hook_node_key('C')
    single_edge_path = [(c_post_hook, InsertionPointGraph.get_pre_hook_node_key('F'))]
    qpsg.propagate_quantizer_via_path(first_quantizer, single_edge_path)

    _ = qpsg.add_propagating_quantizer([QuantizerConfig(bits=6)],
                                       InsertionPointGraph.get_pre_hook_node_key('F'))
    return qpsg
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
    """Add default quantizers at both input ports of 'I' and at the 'C' and 'D' pre-hooks.

    :param qpsg: Quantizer propagation state graph to populate; modified in place.
    :return: The same `qpsg` object.
    """
    # This case will fail if, after going depth-first through the 'D' branch of the graph,
    # the merge traversal function state is not reset (which is incorrect behavior)
    # when starting to traverse the 'C' branch.
    pre_hook_keys = (
        InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=0),
        InsertionPointGraph.get_pre_hook_node_key('I', in_port_id=1),
        InsertionPointGraph.get_pre_hook_node_key('C'),
        InsertionPointGraph.get_pre_hook_node_key('D'),
    )
    for pre_hook_key in pre_hook_keys:
        qpsg.add_propagating_quantizer([QuantizerConfig()], pre_hook_key)
    return qpsg
def test_quantize_range_init_sets_correct_scale_shapes(
        quantizer_range_init_test_struct: Tuple[QRISSTS, str]):
    """Check that min/max range initialization keeps the expected parameter shapes.

    For both symmetric and asymmetric modes a quantizer is built, statistics
    are collected from an all-ones input, min/max init is applied, and the
    shapes of the quantizer parameters are compared with the reference.

    :param quantizer_range_init_test_struct: Pair of (test structure, initializer type).
    """
    test_struct, initializer_type = quantizer_range_init_test_struct
    ref_shape = test_struct.ref_scale_shape

    # Parameters whose shape must equal the reference shape, per quantization mode.
    shaped_params = {
        QuantizationMode.SYMMETRIC: ('scale',),
        QuantizationMode.ASYMMETRIC: ('input_low', 'input_range'),
    }

    for quantization_mode in (QuantizationMode.SYMMETRIC, QuantizationMode.ASYMMETRIC):
        qconfig = QuantizerConfig(mode=quantization_mode,
                                  per_channel=test_struct.per_channel,
                                  is_weights=test_struct.is_weights,
                                  input_shape=test_struct.input_shape)
        quantizer = QUANTIZATION_MODULES.get(quantization_mode)(qconfig)  # type: BaseQuantizer

        range_init_config = RangeInitConfig(init_type=initializer_type, num_init_samples=1)
        reduction_shape = tuple(quantizer.scale_shape)
        collector = range_init_config.generate_stat_collector(reduction_shapes={reduction_shape})
        collector.register_input(torch.ones(test_struct.input_shape))
        collected_stat = collector.get_statistics()[tuple(quantizer.scale_shape)]
        minmax_values = MinMaxTensorStatistic.from_stat(collected_stat)
        quantizer.apply_minmax_init(min_values=minmax_values.min_values,
                                    max_values=minmax_values.max_values)

        assert quantizer.scale_shape == ref_shape
        # Both looped modes have an entry above, so the lookup is exhaustive.
        for param_name in shaped_params[quantization_mode]:
            assert list(getattr(quantizer, param_name).shape) == ref_shape
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
    """Merge per-channel quantizers at the 'B' post-hook and propagate a default one from 'E' to 'D'.

    :param qpsg: Quantizer propagation state graph to populate; modified in place.
    :return: The same `qpsg` object.
    """
    per_channel_quantizers = [
        qpsg.add_propagating_quantizer([QuantizerConfig(per_channel=True)],
                                       InsertionPointGraph.get_pre_hook_node_key(op_key))
        for op_key in ('C', 'D')
    ]
    qpsg.merge_quantizers_for_branching_node(per_channel_quantizers,
                                             [QuantizerConfig(per_channel=True)],
                                             [None, None],
                                             InsertionPointGraph.get_post_hook_node_key('B'))

    default_quantizer = qpsg.add_propagating_quantizer([QuantizerConfig()],
                                                       InsertionPointGraph.get_pre_hook_node_key('E'))
    candidate_paths = get_edge_paths_for_propagation(qpsg,
                                                     InsertionPointGraph.get_pre_hook_node_key('D'),
                                                     InsertionPointGraph.get_pre_hook_node_key('E'))
    # Only the first of the returned paths is used.
    qpsg.propagate_quantizer_via_path(default_quantizer, candidate_paths[0])
    return qpsg
def test_clone_propagating_quantizer(self, mock_qp_graph, start_target_nodes):
    """Check that a cloned propagating quantizer mirrors the original's state.

    A quantizer is added at the start node, propagated to the target node
    and cloned. The clone must report the same affected nodes, edges, path
    and location, and must be registered on every affected node and edge.

    :param mock_qp_graph: Quantizer propagation state graph fixture.
    :param start_target_nodes: Indexable holding the start and target node keys.
    """
    start_key, target_key = start_target_nodes[0], start_target_nodes[1]

    # From "target" to "start" since propagation direction is inverse to edge direction
    # Only take one path out of possible paths for this test
    candidate_paths = get_edge_paths_for_propagation(mock_qp_graph, target_key, start_key)
    rev_path = candidate_paths[0]

    initial_quantizer = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()], start_key)
    prop_quant = mock_qp_graph.propagate_quantizer_via_path(initial_quantizer, rev_path)
    clone = mock_qp_graph.clone_propagating_quantizer(prop_quant)

    mirrored_attributes = ('affected_ip_nodes',
                           'affected_edges',
                           'propagation_path',
                           'current_location_node_key')
    for attr_name in mirrored_attributes:
        assert getattr(clone, attr_name) == getattr(prop_quant, attr_name)

    for node_key in prop_quant.affected_ip_nodes:
        node_attrs = mock_qp_graph.nodes[node_key]
        assert clone in node_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    for src_key, dst_key in prop_quant.affected_edges:
        edge_attrs = mock_qp_graph.edges[src_key, dst_key]
        assert clone in edge_attrs[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    # The cloned quantizer had not been put into any IP (cannot have multiple PQs in one IP right now)
    mock_qp_graph.skip_check = True
def test_backtrack_propagation_until_accepting_location(self, start_target_accepting_nodes, mock_qp_graph):
    """Check backtracking of a propagated quantizer to its last accepting location.

    A quantizer is added at the start node and propagated to the target
    node; its `last_accepting_location_node_key` is then forced to the
    supplied value and the graph is asked to backtrack. The test verifies
    the quantizer's resulting location, affected edges and propagation
    path, and the quantizer attributes of the target and accepting nodes.

    :param start_target_accepting_nodes: Indexable holding the start node key,
        the target node key and the forced last accepting location key.
    :param mock_qp_graph: Quantizer propagation state graph fixture.
    """
    # NOTE(review): the source of this test arrived with its indentation lost;
    # the block nesting below is a reconstruction - confirm against the original file.
    start_ip_node_key = start_target_accepting_nodes[0]
    target_ip_node_key = start_target_accepting_nodes[1]
    forced_last_accepting_location = start_target_accepting_nodes[2]

    prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()], start_ip_node_key)
    # Edges affected right after the add; extended below with the expected post-backtrack path.
    ref_affected_edges = deepcopy(prop_quant.affected_edges)

    # Here, the tested graph should have such a structure that there is only one path from target to start
    path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]
    prop_quant = mock_qp_graph.propagate_quantizer_via_path(prop_quant, path)
    prop_quant.last_accepting_location_node_key = forced_last_accepting_location

    # NOTE(review): `resulting_path` is only bound when the forced location is not None,
    # yet it is read unconditionally below - presumably the parametrization never
    # supplies None; verify against the fixture.
    if forced_last_accepting_location is not None:
        resulting_path = get_edge_paths_for_propagation(mock_qp_graph,
                                                        forced_last_accepting_location, start_ip_node_key)[0]
        ref_affected_edges.update(set(resulting_path))

    prop_quant = mock_qp_graph.backtrack_propagation_until_accepting_location(prop_quant)

    assert prop_quant.current_location_node_key == forced_last_accepting_location
    assert prop_quant.affected_edges == ref_affected_edges
    assert prop_quant.propagation_path == resulting_path

    target_node = mock_qp_graph.nodes[target_ip_node_key]
    accepting_node = mock_qp_graph.nodes[forced_last_accepting_location]
    # When the quantizer moved back from the target, the target must no longer hold or list it.
    if forced_last_accepting_location != target_ip_node_key:
        assert target_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
        assert target_ip_node_key not in prop_quant.affected_ip_nodes
    assert accepting_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
def test_quantizers_can_propagate_via_path(self, start_target_nodes, mock_qp_graph):
    """Check propagation of a quantizer along every path from start to target.

    For each path a fresh copy of the graph is used: a quantizer is added at
    the start node and propagated along the path. The test then verifies
    the attributes of every edge and node on the path, the graph
    consistency, and the final state of the propagated quantizer.

    :param start_target_nodes: Indexable holding the start and target node keys.
    :param mock_qp_graph: Quantizer propagation state graph fixture.
    """
    start_ip_node_key = start_target_nodes[0]
    target_ip_node_key = start_target_nodes[1]

    # From "target" to "start" since propagation direction is inverse to edge direction
    rev_paths = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)

    for path in rev_paths:
        # Work on a copy so that each path starts from the pristine fixture graph.
        working_graph = deepcopy(mock_qp_graph)
        ref_prop_quant = working_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                 start_ip_node_key)
        ref_affected_edges = deepcopy(ref_prop_quant.affected_edges)
        ref_affected_edges.update(set(path))
        ref_affected_ip_nodes = deepcopy(ref_prop_quant.affected_ip_nodes)

        prop_quant = working_graph.propagate_quantizer_via_path(ref_prop_quant, path)

        for from_node_key, to_node_key in path:
            edge_data = working_graph.edges[from_node_key, to_node_key]
            assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [ref_prop_quant]

            # Insertion points the quantizer has passed through must be vacated.
            to_node = working_graph.nodes[to_node_key]
            if to_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                assert to_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

            from_node = working_graph.nodes[from_node_key]
            if from_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                ref_affected_ip_nodes.add(from_node_key)

        working_graph.run_consistency_check()

        # The quantizer must end up at the source node of the last edge of the path.
        final_node_key, _ = path[-1]
        final_node = working_graph.nodes[final_node_key]
        assert final_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == ref_prop_quant

        assert prop_quant.current_location_node_key == final_node_key
        assert prop_quant.propagation_path == path
        assert prop_quant.affected_edges == ref_affected_edges
        assert prop_quant.affected_ip_nodes == ref_affected_ip_nodes
def test_quantization_configs__with_defaults():
    """Check that the basic quantization config yields 8-bit symmetric quantizers.

    The compressed model is built from the basic config and every weight
    and non-weight quantizer held by the controller is compared with the
    corresponding reference config.
    """
    model = BasicConvTestModel()
    config = get_basic_quantization_config()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    assert isinstance(compression_ctrl, QuantizationController)

    quantizer_groups = (
        (True, compression_ctrl.weight_quantizers),
        (False, compression_ctrl.non_weight_quantizers),
    )
    for is_weights, quantizers in quantizer_groups:
        # The reference configs differ only in the trailing `is_weights` argument.
        ref_qconfig = QuantizerConfig(8, QuantizationMode.SYMMETRIC, None, False, None, is_weights)
        for quantizer in quantizers.values():
            compare_qconfigs(ref_qconfig, quantizer)
def test_quantize_range_init_sets_correct_scale_shapes(
        quantizer_range_init_test_struct: Tuple[QRISSTS, str]):
    """Check that range initialization keeps the expected parameter shapes.

    For both symmetric and asymmetric modes a quantizer is built, an
    initializer created by `RangeInitializerFactory` is fed an all-ones
    input and applied, and the shapes of the quantizer parameters are
    compared with the reference.

    :param quantizer_range_init_test_struct: Pair of (test structure, initializer type).
    """
    test_struct, initializer_type = quantizer_range_init_test_struct
    ref_shape = test_struct.ref_scale_shape

    # Parameters whose shape must equal the reference shape, per quantization mode.
    shaped_params = {
        QuantizationMode.SYMMETRIC: ('scale',),
        QuantizationMode.ASYMMETRIC: ('input_low', 'input_range'),
    }

    for quantization_mode in (QuantizationMode.SYMMETRIC, QuantizationMode.ASYMMETRIC):
        qconfig = QuantizerConfig(mode=quantization_mode,
                                  per_channel=test_struct.per_channel,
                                  is_weights=test_struct.is_weights,
                                  input_shape=test_struct.input_shape)
        quantizer = QUANTIZATION_MODULES.get(quantization_mode)(qconfig)  # type: BaseQuantizer

        initializer = RangeInitializerFactory.create(
            {"type": initializer_type, "num_init_steps": 1}, quantizer, "")
        initializer.register_input(torch.ones(test_struct.input_shape))
        with torch.no_grad():
            initializer.apply_init()

        assert quantizer.scale_shape == ref_shape
        # Both looped modes have an entry above, so the lookup is exhaustive.
        for param_name in shaped_params[quantization_mode]:
            assert list(getattr(quantizer, param_name).shape) == ref_shape
def get_qconf_from_hw_config_subdict(quantization_subdict: Dict, for_weights=False):
    """Build a QuantizerConfig from one quantization entry of a hardware config.

    :param quantization_subdict: Mapping with the keys "bits", "mode" and
        "granularity"; may additionally hold "level_low" and "level_high".
    :param for_weights: Passed through as `is_weights` of the resulting config.
    :return: QuantizerConfig with bits, mode, per-channel flag and forced
        signedness taken from the mapping. Signedness is left as None unless
        both level keys are present.
    :raises AssertionError: If "level_low"/"level_high" are present but differ
        from the level range calculated for the given bitwidth and mode.
    """
    bits = quantization_subdict["bits"]
    mode = HWConfig.get_quantization_mode_from_config_value(quantization_subdict["mode"])
    is_per_channel = HWConfig.get_is_per_channel_from_config_value(quantization_subdict["granularity"])

    signedness_to_force = None
    if 'level_low' in quantization_subdict and 'level_high' in quantization_subdict:
        level_low = quantization_subdict['level_low']
        level_high = quantization_subdict['level_high']
        if mode == QuantizationMode.SYMMETRIC:
            # Symmetric: signed only when the level range straddles zero.
            signedness_to_force = bool(level_low < 0 < level_high)
            true_level_low, true_level_high, _ = SymmetricQuantizer.calculate_level_ranges(bits, True)
        else:
            signedness_to_force = True
            true_level_low, true_level_high, _ = AsymmetricQuantizer.calculate_level_ranges(bits)

        # Messages use implicit string concatenation: a backslash continuation
        # inside the literal would embed the source indentation in the message.
        assert level_low == true_level_low, (
            "Invalid value of quantizer parameter `level_low`. "
            "The parameter must be consistent with other parameters!")
        assert level_high == true_level_high, (
            "Invalid value of quantizer parameter `level_high`. "
            "The parameter must be consistent with other parameters!")

    return QuantizerConfig(bits=bits,
                           mode=mode,
                           per_channel=is_per_channel,
                           signedness_to_force=signedness_to_force,
                           is_weights=for_weights)
def test_check_hawq_dump(mocker, tmp_path):
    """Check that HAWQDebugger dumps produce six files in the `hawq_dumps` directory.

    The debugger is fed small hand-made inputs (two configs, two observers,
    four perturbations), all dump methods are called, and the regular files
    in `<tmp_path>/hawq_dumps` are counted.
    """
    tensor1 = torch.Tensor([1])
    tensor2 = torch.Tensor([2])
    qconf1 = QuantizerConfig(bits=2)
    qconf2 = QuantizerConfig(bits=4)
    id_ = 0

    quantizer_configurations = [[qconf1, qconf1], [qconf2, qconf2]]
    flops_per_config = [tensor1.item(), tensor2.item()]
    chosen_config_index = id_
    configuration_metric = [tensor1, tensor2]

    perturbations = Perturbations()
    perturbation_entries = (
        (id_, qconf1, tensor1),
        (id_, qconf2, tensor2),
        (id_ + 1, qconf1, tensor2),
        (id_ + 1, qconf2, tensor1),
    )
    for layer_id, qconf, value in perturbation_entries:
        perturbations.add(layer_id, qconf, value)

    weight_observers = []
    for perturbation in (tensor1, tensor2):
        observer = PerturbationObserver(mocker.stub())
        observer.perturbation = perturbation
        observer.numels = id_
        observer.input_norm = id_
        weight_observers.append(observer)

    traces_per_layer = TracesPerLayer(torch.cat((tensor1, tensor2)))

    set_debug_log_dir(str(tmp_path))
    hawq_debugger = HAWQDebugger(quantizer_configurations,
                                 perturbations,
                                 quantizer_configurations,
                                 [weight_observers, weight_observers],
                                 traces_per_layer,
                                 [qconf1.bits, qconf2.bits])

    hawq_debugger.dump_metric_MB(configuration_metric)
    hawq_debugger.dump_metric_flops(configuration_metric, flops_per_config, chosen_config_index)
    hawq_debugger.dump_avg_traces()
    hawq_debugger.dump_density_of_quantization_noise()
    hawq_debugger.dump_perturbations_ratio()

    # Count only regular files; subdirectories are ignored.
    test_dir = tmp_path / Path('hawq_dumps')
    num_dump_files = sum(1 for name in os.listdir(test_dir)
                         if os.path.isfile(os.path.join(test_dir, name)))
    assert num_dump_files == 6
def validate_spy(self):
    """Validate the arguments recorded by the two spies held on this object.

    Reads the second positional argument of `set_chosen_config_spy.call_args`
    as the chosen bitwidth list and the sixth positional argument of
    `hessian_trace_estimator_spy.call_args` as the init data loader.
    """
    chosen_bitwidths = self.set_chosen_config_spy.call_args[0][1]
    assert len(chosen_bitwidths) == self.n_weight_quantizers

    # with default compression ratio = 1.5 all precisions should be different from the default one
    default_bitwidth = QuantizerConfig().bits
    assert set(chosen_bitwidths) != {default_bitwidth}

    init_data_loader = self.hessian_trace_estimator_spy.call_args[0][5]
    # Fall back to the regular batch size when no dedicated init batch size is set.
    expected_batch_size = self.batch_size_init or self.batch_size
    assert init_data_loader.batch_size == expected_batch_size
def get_qconf_from_hw_config_subdict(quantization_subdict: Dict):
    """Build a QuantizerConfig from one quantization entry of a hardware config.

    :param quantization_subdict: Mapping with the keys "bits", "mode" and "granularity".
    :return: QuantizerConfig with the bitwidth, quantization mode and
        per-channel flag derived from the mapping.
    """
    return QuantizerConfig(
        bits=quantization_subdict["bits"],
        mode=HWConfig.get_quantization_mode_from_config_value(quantization_subdict["mode"]),
        per_channel=HWConfig.get_is_per_channel_from_config_value(quantization_subdict["granularity"]))
def test_quantization_configs__with_defaults():
    """Check that the basic quantization config yields 8-bit symmetric quantizers.

    The compression algorithm is built on a copy of the model and every
    weight and activation quantizer config is compared with the
    corresponding reference config.
    """
    model = BasicConvTestModel()
    config = get_basic_quantization_config()
    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)

    compression_algo = create_compression_algorithm(deepcopy(model), config)
    weight_quantizers, activation_quantizers = split_quantizers(compression_algo.model)

    quantizer_groups = (
        (True, weight_quantizers),
        (False, activation_quantizers),
    )
    for is_weights, quantizers in quantizer_groups:
        # The reference configs differ only in the trailing `is_weights` argument.
        ref_qconfig = QuantizerConfig(8, QuantizationMode.SYMMETRIC, None, False, None, is_weights)
        for quantizer in quantizers:
            compare_qconfigs(ref_qconfig, quantizer.config)
def test_quantization_configs__custom():
    """Check that custom weight/activation settings reach the created quantizers.

    The config without range init is overridden with 4-bit asymmetric
    settings and target device 'NONE'. Weight quantizers and the quantizer
    modules of the non-weight quantizer infos are then compared with
    reference configs.
    """
    model = BasicConvTestModel()
    config = get_quantization_config_without_range_init()
    custom_settings = {
        "weights": {
            "mode": "asymmetric",
            "per_channel": True,
            "bits": 4
        },
        "activations": {
            "mode": "asymmetric",
            "bits": 4,
            "signed": True,
        },
    }
    config['compression'].update(custom_settings)
    config['target_device'] = 'NONE'

    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    assert isinstance(compression_ctrl, QuantizationController)
    weight_quantizers = compression_ctrl.weight_quantizers
    activation_quantizer_infos = compression_ctrl.non_weight_quantizers

    expected_weight_qconfig = QuantizerConfig(bits=4,
                                              mode=QuantizationMode.ASYMMETRIC,
                                              signedness_to_force=None,
                                              per_channel=True,
                                              input_shape=None,
                                              is_weights=True)
    for weight_quantizer in weight_quantizers.values():
        compare_qconfigs(expected_weight_qconfig, weight_quantizer)

    expected_activation_qconfig = QuantizerConfig(bits=4,
                                                  mode=QuantizationMode.ASYMMETRIC,
                                                  signedness_to_force=True,
                                                  per_channel=False,
                                                  input_shape=None,
                                                  is_weights=False)
    for quantizer_info in activation_quantizer_infos.values():
        compare_qconfigs(expected_activation_qconfig, quantizer_info.quantizer_module_ref)
def generate_qp(scope_str: str, target: QuantizerGroup, in_port_id: int = None) -> SingleConfigQuantizationPoint:
    """Create a quantization point with a default QuantizerConfig.

    :param scope_str: String parsed into a module scope (weights) or an
        input-agnostic operation execution context (activations).
    :param target: Quantizer group selecting the kind of insertion point.
    :param in_port_id: Input port id; for activations, None selects an
        operator post-hook and any other value an operator pre-hook.
    :return: Quantization point at the constructed insertion point.
    :raises RuntimeError: If `target` is neither WEIGHTS nor ACTIVATIONS.
    """
    if target is not QuantizerGroup.WEIGHTS and target is not QuantizerGroup.ACTIVATIONS:
        raise RuntimeError()

    if target is QuantizerGroup.WEIGHTS:
        insertion_point = InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP,
                                         module_scope=Scope.from_str(scope_str))
    else:
        if in_port_id is None:
            hook_type = InsertionType.OPERATOR_POST_HOOK
        else:
            hook_type = InsertionType.OPERATOR_PRE_HOOK
        exec_context = InputAgnosticOperationExecutionContext.from_str(scope_str)
        insertion_point = InsertionPoint(hook_type,
                                         ia_op_exec_context=exec_context,
                                         input_port_id=in_port_id)

    return SingleConfigQuantizationPoint(insertion_point, QuantizerConfig())
def test_flops(config_creator, ref_values):
    """Check compression ratios computed for a small conv + linear model.

    :param config_creator: Callable returning the compression config.
    :param ref_values: Indexable of five reference values, compared in order
        with two `ratio_for_bits_configuration` results and three
        `ratio_limits` results (the last one computed with constraints).
    """
    class ConvLinear(nn.Module):
        """Convolution followed by a fully connected layer."""

        def __init__(self):
            super().__init__()
            self.conv1 = create_conv(1, 1, 2, -1, -2)
            self.fc = nn.Linear(3, 6)

        def forward(self, x):
            conv_out = self.conv1(x)
            return self.fc(conv_out)

    config = config_creator()
    model, compression_ctrl = create_compressed_model_and_algo_for_test(ConvLinear(), config)
    quantizers = compression_ctrl.weight_quantizers

    handler = WeightQuantizersHandler(model, quantizers, HWPrecisionConstraints(True))
    ratio_calculator = CompressionRatioCalculator(model, handler)

    assert ratio_calculator.ratio_for_bits_configuration([4, 8]) == ref_values[0]
    assert ratio_calculator.ratio_for_bits_configuration([8, 4]) == ref_values[1]
    assert ratio_calculator.ratio_limits([4, 8]) == ref_values[2]
    assert ratio_calculator.ratio_limits([2, 4, 8]) == ref_values[3]

    # Constrain the first weight quantizer to 8 bits only.
    first_quantizer_key = list(quantizers)[0]
    constraints = HWPrecisionConstraints(True).add(first_quantizer_key, [QuantizerConfig(bits=8)])
    assert ratio_calculator.ratio_limits([2, 8], constraints) == ref_values[4]
# Script entry point: profile quantizer implementations on CUDA for every entry
# of TEST_PARAMS_STRUCT, printing a header before each profiling run.
# NOTE(review): this snippet appears to be cut off - the last header below has no
# matching run_profile call in view; confirm against the full file.
if __name__ == '__main__':
    for input_name, input_size, gpu_runs in TEST_PARAMS_STRUCT:
        print("CUDA " + input_name)
        print("------------------------------------------------")

        # Reference implementation, default (non per-channel) settings.
        print("Pytorch Symmetric (cuda 0) impl:")
        print("input size: {0}".format(input_size))
        run_profile(
            ReferenceQuantize(NBITS).cuda(),
            input_size,
            'cuda',
            gpu_runs)

        print()
        # Custom symmetric quantizer with the same bitwidth.
        print("Custom Symmetric (cuda 0 ) impl:")
        print("input size: {0}".format(input_size))
        run_profile(
            SymmetricQuantizer(QuantizerConfig(
                QuantizationParams(bits=NBITS))).cuda(),
            input_size,
            'cuda',
            gpu_runs)

        print()
        # Reference implementation in per-channel weight mode.
        print("Pytorch Symmetric Per Weight Channel (cuda 0) impl:")
        print("input size: {0}".format(input_size))
        run_profile(
            ReferenceQuantize(NBITS,
                              input_shape=input_size,
                              per_channel=True,
                              is_weights=True).cuda(),
            input_size,
            'cuda',
            gpu_runs)

        print()
        print("Custom Symmetric Per Weight Channel (cuda 0 ) impl")
        print("input size: {0}".format(input_size))
# nx_graph is expected to have version-agnostic operator names already for k, attrs in nx_graph.nodes.items(): attrs = {k: str(v) for k, v in attrs.items()} load_attrs = {k: str(v).strip('"') for k, v in load_graph.nodes[k].items()} assert attrs == load_attrs assert load_graph.nodes.keys() == nx_graph.nodes.keys() assert nx.DiGraph(load_graph).edges == nx_graph.edges QuantizeConfig = namedtuple('QuantizeConfig', ['quantizer', 'graph_dir']) QUANTIZERS = [ QuantizeConfig(lambda _, is_weights=False, input_shape=None: SymmetricQuantizer( QuantizerConfig(signedness_to_force=is_weights, is_weights=is_weights, input_shape=input_shape)), 'symmetric'), QuantizeConfig(lambda _, is_weights, input_shape=None: AsymmetricQuantizer(QuantizerConfig()), 'asymmetric') ] @pytest.fixture(scope='function', params=QUANTIZERS, ids=[pair.graph_dir for pair in QUANTIZERS]) def _quantize_config(request): config = request.param graph_dir = os.path.join('quantized', config.graph_dir) return QuantizeConfig(config.quantizer, graph_dir) def default_forward_fn(model, input_size_): device = next(model.parameters()).device return model(torch.zeros(input_size_).to(device))
def test_onnx_export_to_quantize_dequantize_per_channel():
    """Exercise quantize-dequantize-pair export for per-channel and per-tensor quantizers.

    For both the symmetric and the asymmetric quantizer, three exports are run
    in the ONNX_QUANTIZE_DEQUANTIZE_PAIRS export mode:
      1. per-channel config with the quantizer's initial parameters,
      2. per-tensor config,
      3. per-channel config after one per-channel parameter (`scale` for the
         symmetric, `input_range` for the asymmetric quantizer) is replaced by
         random values; if this export raises RuntimeError, its message must
         equal the expected "doesn't support per channel quantization" text.
    """
    input_shape = (2, 64, 15, 10)
    expected_error = "PyTorch v1.5.0 export to ONNX using QuantizeLinear-DequantizeLinear " \
                     "doesn't support per channel quantization"

    def make_quantizer(quantizer_cls, mode, per_channel):
        """Build a quantizer in quantize-dequantize export mode; input_shape is set only per-channel."""
        config_kwargs = dict(bits=8, mode=mode, signedness_to_force=None, per_channel=per_channel)
        if per_channel:
            config_kwargs['input_shape'] = input_shape
        quantizer = quantizer_cls(QuantizerConfig(**config_kwargs))
        # pylint: disable=protected-access
        quantizer._export_mode = QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS
        return quantizer

    # (quantizer class, quantization mode, per-channel parameter to randomize)
    cases = (
        (SymmetricQuantizer, QuantizationMode.SYMMETRIC, 'scale'),
        # The randomized parameter must belong to the asymmetric quantizer itself;
        # assigning `scale` on the symmetric quantizer here left this case untested.
        (AsymmetricQuantizer, QuantizationMode.ASYMMETRIC, 'input_range'),
    )
    for quantizer_cls, mode, per_channel_param in cases:
        quantizer = make_quantizer(quantizer_cls, mode, per_channel=True)
        x = torch.rand(input_shape)
        quantizer.run_export_quantization(x)

        quantizer = make_quantizer(quantizer_cls, mode, per_channel=False)
        x = torch.rand(input_shape)
        quantizer.run_export_quantization(x)

        quantizer = make_quantizer(quantizer_cls, mode, per_channel=True)
        setattr(quantizer, per_channel_param, torch.nn.Parameter(torch.rand(1, 64, 1, 1)))
        x = torch.rand(input_shape)
        # The export is allowed to succeed; only the message of a raised error is checked.
        try:
            quantizer.run_export_quantization(x)
        except RuntimeError as e:
            assert str(e) == expected_error
def get_mock_precision_constraints(constraints, ordered_weight_keys):
    """Build HWPrecisionConstraints from per-key lists of allowed bitwidths.

    :param constraints: Iterable of bitwidth iterables, paired positionally
        with `ordered_weight_keys`.
    :param ordered_weight_keys: Iterable of keys to register constraints for.
    :return: HWPrecisionConstraints with one QuantizerConfig per bitwidth added for each key.
    """
    precision_constraints = HWPrecisionConstraints(True)
    for weight_key, bitwidths in zip(ordered_weight_keys, constraints):
        allowed_qconfigs = []
        for bitwidth in bitwidths:
            allowed_qconfigs.append(QuantizerConfig(bits=bitwidth))
        precision_constraints.add(weight_key, allowed_qconfigs)
    return precision_constraints
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
    """Add default propagating quantizers at the pre-hooks of 'B' and 'E'.

    :param qpsg: Quantizer propagation state graph to populate; modified in place.
    :return: The same `qpsg` object.
    """
    for op_key in ('B', 'E'):
        pre_hook_key = InsertionPointGraph.get_pre_hook_node_key(op_key)
        qpsg.add_propagating_quantizer([QuantizerConfig()], pre_hook_key)
    return qpsg
class TestQuantizerPropagationStateGraph:
    #pylint:disable=too-many-public-methods
    """Tests for the state graph (QPSG) that tracks propagating quantizers over an InsertionPointGraph."""

    @staticmethod
    @pytest.fixture()
    def mock_qp_graph():
        """Yield a QPSG built on the two-branch mock model graph.

        After the test body has finished, the graph consistency check is run unless the test
        has set `skip_check = True` on the yielded graph.
        """
        ip_graph = InsertionPointGraph(get_two_branch_mock_model_graph())
        qpsg = QPSG(ip_graph)
        qpsg.skip_check = False
        yield qpsg
        if not qpsg.skip_check:
            qpsg.run_consistency_check()

    def test_build_quantizer_propagation_state_graph_from_ip_graph(self):
        """A freshly built QPSG mirrors the nodes and edges of the insertion point graph it was built from."""
        ip_graph = InsertionPointGraph(get_two_branch_mock_model_graph())
        quant_prop_graph = QPSG(ip_graph)
        assert len(ip_graph.nodes) == len(quant_prop_graph.nodes)
        assert len(ip_graph.edges) == len(quant_prop_graph.edges)

        for ip_graph_node_key, ip_graph_node in ip_graph.nodes.items():
            qpg_node = quant_prop_graph.nodes[ip_graph_node_key]
            assert qpg_node[QPSG.NODE_TYPE_NODE_ATTR] == QPSG.ipg_node_type_to_qpsg_node_type(
                ip_graph_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR])
            qpg_node_type = qpg_node[QPSG.NODE_TYPE_NODE_ATTR]
            if qpg_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                assert qpg_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                assert not qpg_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                assert qpg_node[QPSG.INSERTION_POINT_DATA_NODE_ATTR] == ip_graph_node[
                    InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            elif qpg_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
                assert not qpg_node[QPSG.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR]
                assert qpg_node[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] == QuantizationTrait.NON_QUANTIZABLE
                assert not qpg_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node, to_node, edge_data in ip_graph.edges(data=True):
            qpg_edge_data = quant_prop_graph.edges[from_node, to_node]
            assert not qpg_edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            # Every attribute of the original edge must be carried over unchanged.
            for key, value in edge_data.items():
                assert qpg_edge_data[key] == value

        quant_prop_graph.run_consistency_check()

    def test_add_propagating_quantizer(self, mock_qp_graph):
        """Adding a quantizer at the pre-hook of "F" registers it on that IP node, on "F" and on the edge between.

        Adding a quantizer at the post-hook of "F" must raise RuntimeError.
        """
        ref_qconf_list = [QuantizerConfig(), QuantizerConfig(bits=6)]

        target_node_key = "F"
        target_ip_node_key = InsertionPointGraph.get_pre_hook_node_key(target_node_key)
        prop_quant = mock_qp_graph.add_propagating_quantizer(ref_qconf_list, target_ip_node_key)
        assert prop_quant.potential_quant_configs == ref_qconf_list
        assert prop_quant.current_location_node_key == target_ip_node_key
        assert prop_quant.affected_ip_nodes == {target_ip_node_key}
        assert prop_quant.last_accepting_location_node_key is None
        assert prop_quant.affected_edges == {(target_ip_node_key, target_node_key)}
        assert not prop_quant.propagation_path

        for node_key, node in mock_qp_graph.nodes.items():
            if node_key == target_node_key:
                assert node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            elif node_key == target_ip_node_key:
                assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant
            else:
                # All the other nodes must stay untouched.
                assert not node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                if node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        target_ip_node = mock_qp_graph.nodes[target_ip_node_key]
        assert target_ip_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

        for from_node, to_node, edge_data in mock_qp_graph.edges.data():
            if (from_node, to_node) == (target_ip_node_key, target_node_key):
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
            else:
                assert not edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        with pytest.raises(RuntimeError):
            _ = mock_qp_graph.add_propagating_quantizer(ref_qconf_list,
                                                        InsertionPointGraph.get_post_hook_node_key(target_node_key))

    # Pairs of (starting IP node key, reference list of edge paths expected from
    # `get_paths_to_immediately_dominating_insertion_points` for that start).
    START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES = [
        # Non-branching case - starting from "E" pre-hook
        (InsertionPointGraph.get_pre_hook_node_key("E"),
         [[(InsertionPointGraph.get_post_hook_node_key("C"), InsertionPointGraph.get_pre_hook_node_key("E"))]]),

        # Non-branching case - starting from "C" post-hook
        (InsertionPointGraph.get_post_hook_node_key("C"),
         [[("C", InsertionPointGraph.get_post_hook_node_key("C")),
           (InsertionPointGraph.get_pre_hook_node_key("C"), "C")]]),

        # Branching case - starting from "F" pre-hook port 0
        (InsertionPointGraph.get_pre_hook_node_key("F"),
         [[(InsertionPointGraph.get_post_hook_node_key("D"), InsertionPointGraph.get_pre_hook_node_key("F"))]]),

        # Branching case - starting from "F" pre-hook port 1
        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         [[(InsertionPointGraph.get_post_hook_node_key("E"),
            InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1))]]),
    ]

    @staticmethod
    @pytest.fixture(params=START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES)
    def start_ip_node_and_path_to_dominating_node(request):
        """Parametrize over START_IP_NODES_AND_PATHS_TO_DOMINATING_IP_NODES."""
        return request.param

    def test_get_paths_to_immediately_dominating_insertion_points(self, start_ip_node_and_path_to_dominating_node,
                                                                  mock_qp_graph):
        """The paths returned for the start node match the reference paths, regardless of their order."""
        start_node = start_ip_node_and_path_to_dominating_node[0]
        ref_paths = start_ip_node_and_path_to_dominating_node[1]
        test_paths = mock_qp_graph.get_paths_to_immediately_dominating_insertion_points(start_node)

        def get_cat_path_list(path_list: List[List[Tuple[str, str]]]):
            # Flatten every path into one string so that whole paths can be counted and compared.
            str_paths = [[str(edge[0]) + ' -> ' + str(edge[1]) for edge in path] for path in path_list]
            cat_paths = [';'.join(path) for path in str_paths]
            return cat_paths

        assert Counter(get_cat_path_list(ref_paths)) == Counter(get_cat_path_list(test_paths))

    # Pairs of (IP node key where the quantizer is added, IP node key it is propagated to).
    START_TARGET_NODES = [
        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("F")),

        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key("E")),

        (InsertionPointGraph.get_pre_hook_node_key("F"),
         InsertionPointGraph.get_post_hook_node_key("B")),
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_NODES)
    def start_target_nodes(request):
        """Parametrize over START_TARGET_NODES."""
        return request.param

    @pytest.mark.dependency(name="propagate_via_path")
    def test_quantizers_can_propagate_via_path(self, start_target_nodes, mock_qp_graph):
        """Propagating a quantizer along each possible path updates the edges, IP nodes and quantizer state."""
        start_ip_node_key = start_target_nodes[0]
        target_ip_node_key = start_target_nodes[1]

        # From "target" to "start" since propagation direction is inverse to edge direction
        rev_paths = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)

        for path in rev_paths:
            # Each path is tried on its own copy so that the paths do not interfere with one another.
            working_graph = deepcopy(mock_qp_graph)
            ref_prop_quant = working_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key)
            ref_affected_edges = deepcopy(ref_prop_quant.affected_edges)
            ref_affected_edges.update(set(path))
            ref_affected_ip_nodes = deepcopy(ref_prop_quant.affected_ip_nodes)
            prop_quant = working_graph.propagate_quantizer_via_path(ref_prop_quant, path)

            for from_node_key, to_node_key in path:
                edge_data = working_graph.edges[from_node_key, to_node_key]
                assert edge_data[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [ref_prop_quant]
                to_node = working_graph.nodes[to_node_key]
                if to_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert to_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
                from_node = working_graph.nodes[from_node_key]
                if from_node[QPSG.NODE_TYPE_NODE_ATTR] == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    ref_affected_ip_nodes.add(from_node_key)

            working_graph.run_consistency_check()

            # The quantizer must have ended up in the "from" node of the last edge of the path.
            final_node_key, _ = path[-1]
            final_node = working_graph.nodes[final_node_key]
            assert final_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == ref_prop_quant

            assert prop_quant.current_location_node_key == final_node_key
            assert prop_quant.propagation_path == path
            assert prop_quant.affected_edges == ref_affected_edges
            assert prop_quant.affected_ip_nodes == ref_affected_ip_nodes

    # Triples of (starting IP node key, IP node key to propagate to, forced last accepting location).
    START_TARGET_ACCEPTING_NODES = [
        (InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("F"),
         InsertionPointGraph.get_post_hook_node_key("F")),

        (InsertionPointGraph.get_pre_hook_node_key("F", in_port_id=1),
         InsertionPointGraph.get_pre_hook_node_key("C"),
         InsertionPointGraph.get_post_hook_node_key("C")),

        (InsertionPointGraph.get_pre_hook_node_key("D"),
         InsertionPointGraph.get_pre_hook_node_key("B"),
         InsertionPointGraph.get_post_hook_node_key("B")),
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_ACCEPTING_NODES)
    def start_target_accepting_nodes(request):
        """Parametrize over START_TARGET_ACCEPTING_NODES."""
        return request.param

    @pytest.mark.dependency(depends="propagate_via_path")
    def test_backtrack_propagation_until_accepting_location(self, start_target_accepting_nodes, mock_qp_graph):
        """Backtracking returns the quantizer to the forced last accepting location and restores its state."""
        start_ip_node_key = start_target_accepting_nodes[0]
        target_ip_node_key = start_target_accepting_nodes[1]
        forced_last_accepting_location = start_target_accepting_nodes[2]

        prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                             start_ip_node_key)
        ref_affected_edges = deepcopy(prop_quant.affected_edges)

        # Here, the tested graph should have such a structure that there is only one path from target to start
        path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]
        prop_quant = mock_qp_graph.propagate_quantizer_via_path(prop_quant, path)
        prop_quant.last_accepting_location_node_key = forced_last_accepting_location
        if forced_last_accepting_location is not None:
            resulting_path = get_edge_paths_for_propagation(mock_qp_graph,
                                                            forced_last_accepting_location, start_ip_node_key)[0]
            ref_affected_edges.update(set(resulting_path))

        prop_quant = mock_qp_graph.backtrack_propagation_until_accepting_location(prop_quant)

        assert prop_quant.current_location_node_key == forced_last_accepting_location
        assert prop_quant.affected_edges == ref_affected_edges
        assert prop_quant.propagation_path == resulting_path

        target_node = mock_qp_graph.nodes[target_ip_node_key]
        accepting_node = mock_qp_graph.nodes[forced_last_accepting_location]
        if forced_last_accepting_location != target_ip_node_key:
            # The quantizer has been moved back, so the former target must no longer know about it.
            assert target_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None
            assert target_ip_node_key not in prop_quant.affected_ip_nodes
        assert accepting_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] == prop_quant

    @pytest.mark.dependency(depends="propagate_via_path")
    def test_clone_propagating_quantizer(self, mock_qp_graph, start_target_nodes):
        """A cloned quantizer copies the state of the original and is registered on the same IP nodes and edges."""
        start_ip_node_key = start_target_nodes[0]
        target_ip_node_key = start_target_nodes[1]

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path = get_edge_paths_for_propagation(mock_qp_graph, target_ip_node_key, start_ip_node_key)[0]

        ref_prop_quant = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                 start_ip_node_key)

        prop_quant = mock_qp_graph.propagate_quantizer_via_path(ref_prop_quant, rev_path)

        cloned_prop_quant = mock_qp_graph.clone_propagating_quantizer(prop_quant)

        assert cloned_prop_quant.affected_ip_nodes == prop_quant.affected_ip_nodes
        assert cloned_prop_quant.affected_edges == prop_quant.affected_edges
        assert cloned_prop_quant.propagation_path == prop_quant.propagation_path
        assert cloned_prop_quant.current_location_node_key == prop_quant.current_location_node_key

        for ip_node_key in prop_quant.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert cloned_prop_quant in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        for from_node_key, to_node_key in prop_quant.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert cloned_prop_quant in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The cloned quantizer had not been put into any IP (cannot have multiple PQs in one IP right now)
        mock_qp_graph.skip_check = True

    # Tuples of (start and target IP node keys of the quantizer to be removed,
    # start and target IP node keys of the quantizer to be kept).
    START_TARGET_NODES_FOR_TWO_QUANTIZERS = [
        (InsertionPointGraph.get_pre_hook_node_key("E"),
         InsertionPointGraph.get_post_hook_node_key("C"),
         InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_post_hook_node_key("G")),

        (InsertionPointGraph.get_pre_hook_node_key("C"),
         InsertionPointGraph.get_post_hook_node_key("A"),
         InsertionPointGraph.get_pre_hook_node_key("H"),
         InsertionPointGraph.get_pre_hook_node_key("D")),

        # Simulated quantizer merging result
        (InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_pre_hook_node_key("E"),
         InsertionPointGraph.get_pre_hook_node_key("G"),
         InsertionPointGraph.get_post_hook_node_key("D"))
    ]

    @staticmethod
    @pytest.fixture(params=START_TARGET_NODES_FOR_TWO_QUANTIZERS)
    def start_target_nodes_for_two_quantizers(request):
        """Parametrize over START_TARGET_NODES_FOR_TWO_QUANTIZERS."""
        return request.param

    @pytest.mark.dependency(depends="propagate_via_path")
    def test_remove_propagating_quantizer(self, mock_qp_graph, start_target_nodes_for_two_quantizers):
        """Removing one of two propagated quantizers wipes it from the graph and leaves the other one intact."""
        start_ip_node_key_remove = start_target_nodes_for_two_quantizers[0]
        target_ip_node_key_remove = start_target_nodes_for_two_quantizers[1]

        start_ip_node_key_keep = start_target_nodes_for_two_quantizers[2]
        target_ip_node_key_keep = start_target_nodes_for_two_quantizers[3]

        # From "target" to "start" since propagation direction is inverse to edge direction
        # Only take one path out of possible paths for this test
        rev_path_remove = get_edge_paths_for_propagation(mock_qp_graph,
                                                         target_ip_node_key_remove,
                                                         start_ip_node_key_remove)[0]
        rev_path_keep = get_edge_paths_for_propagation(mock_qp_graph,
                                                       target_ip_node_key_keep,
                                                       start_ip_node_key_keep)[0]

        prop_quant_to_remove = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                       start_ip_node_key_remove)
        prop_quant_to_remove = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_remove, rev_path_remove)

        prop_quant_to_keep = mock_qp_graph.add_propagating_quantizer([QuantizerConfig()],
                                                                     start_ip_node_key_keep)
        prop_quant_to_keep = mock_qp_graph.propagate_quantizer_via_path(prop_quant_to_keep, rev_path_keep)

        # Snapshot the state of the quantizer to be removed before the removal,
        # so that every place it used to affect can be inspected afterwards.
        affected_ip_nodes = deepcopy(prop_quant_to_remove.affected_ip_nodes)
        affected_op_nodes = deepcopy(prop_quant_to_remove.affected_operator_nodes)
        # The edges to inspect are the ones of the removed quantizer itself. Previously only
        # the edges of the kept quantizer were snapshotted here, so the removed quantizer's own
        # edges were never verified. The kept quantizer's edges stay in the set as well.
        affected_edges = deepcopy(prop_quant_to_remove.affected_edges)
        affected_edges.update(prop_quant_to_keep.affected_edges)
        last_location = prop_quant_to_remove.current_location_node_key
        ref_quant_to_keep_state_dict = deepcopy(prop_quant_to_keep.__dict__)

        mock_qp_graph.remove_propagating_quantizer(prop_quant_to_remove)

        last_node = mock_qp_graph.nodes[last_location]
        assert last_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR] is None

        for ip_node_key in affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for op_node_key in affected_op_nodes:
            node = mock_qp_graph.nodes[op_node_key]
            assert prop_quant_to_remove not in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_remove not in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        # The kept quantizer must be left exactly as it was before the removal.
        assert prop_quant_to_keep.__dict__ == ref_quant_to_keep_state_dict

        for ip_node_key in prop_quant_to_keep.affected_ip_nodes:
            node = mock_qp_graph.nodes[ip_node_key]
            assert prop_quant_to_keep in node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

        for from_node_key, to_node_key in prop_quant_to_keep.affected_edges:
            edge = mock_qp_graph.edges[from_node_key, to_node_key]
            assert prop_quant_to_keep in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

    # Pairs of (operator node keys to be marked as non-agnostic,
    # {start node key: set of node keys expected to be reported for that start node}).
    QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES = [
        (["D", "E", "F"],
         {
             "B": {"D", "E"},
             InsertionPointGraph.get_pre_hook_node_key("B"): {"D", "E"},
             "E": {"F"},
             InsertionPointGraph.get_post_hook_node_key("D"): {"F"},
             "A": {"D", "E"},
             InsertionPointGraph.get_pre_hook_node_key("G"): set()
         }),
        (["C", "E", "H"],
         {
             InsertionPointGraph.get_pre_hook_node_key("C"): {"C"},
             InsertionPointGraph.get_post_hook_node_key("C"): {"E"},
             "D": {"H"},
             InsertionPointGraph.get_pre_hook_node_key("B"): {"C", "H"},  # corner case - has a branch without quantizers
             InsertionPointGraph.get_post_hook_node_key("H"): set()
         })
    ]

    @staticmethod
    @pytest.fixture(params=QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES)
    def dominated_nodes_test_struct(request):
        """Parametrize over QUANTIZABLE_NODES_START_NODES_DOMINATED_NODES."""
        return request.param

    @staticmethod
    def mark_nodes_with_traits(qpsg: QPSG, node_keys_vs_traits_dict: Dict[str, QuantizationTrait]) -> QPSG:
        """Write the given quantization traits onto the graph nodes in place.

        :param qpsg: the graph whose nodes are to be marked.
        :param node_keys_vs_traits_dict: maps a node key to the trait to set; nodes whose keys
            are missing from the dict are left as they are.
        :return: the same `qpsg` object.
        """
        for node_key, node in qpsg.nodes.items():
            if node_key in node_keys_vs_traits_dict:
                node[QPSG.QUANTIZATION_TRAIT_NODE_ATTR] = node_keys_vs_traits_dict[node_key]
        return qpsg

    def test_get_quantizable_op_nodes_immediately_dominated_by_node(self, mock_qp_graph,
                                                                    dominated_nodes_test_struct):
        """The reported dominated nodes match the reference for both non-agnostic traits."""
        nodes_to_mark_as_quantizable = dominated_nodes_test_struct[0]

        # Everything is quantization-agnostic unless explicitly marked otherwise below.
        node_keys_vs_trait_dict = dict.fromkeys(mock_qp_graph.nodes, QuantizationTrait.QUANTIZATION_AGNOSTIC)

        traits_to_mark_with = [QuantizationTrait.INPUTS_QUANTIZABLE,
                               QuantizationTrait.NON_QUANTIZABLE]

        for trait in traits_to_mark_with:
            for node_key in nodes_to_mark_as_quantizable:
                node_keys_vs_trait_dict[node_key] = trait

            mock_qp_graph = self.mark_nodes_with_traits(mock_qp_graph, node_keys_vs_trait_dict)
            for start_node_key, ref_dominated_quantizable_nodes_set in dominated_nodes_test_struct[1].items():
                dominated_quantizable_nodes_list = \
                    mock_qp_graph.get_non_quant_agnostic_op_nodes_immediately_dominated_by_node(start_node_key)
                assert set(dominated_quantizable_nodes_list) == ref_dominated_quantizable_nodes_set

    @staticmethod
    def get_model_graph():
        """Build the small branching mock graph used by the quantizer merging tests."""
        mock_node_attrs = get_mock_nncf_node_attrs()
        mock_graph = nx.DiGraph()

        #     (A)
        #      |
        #     (B)
        #   /     \
        # (C)     (D)
        #  |       |
        # (F)     (E)
        #
        #

        node_keys = ['A', 'B', 'C', 'D', 'E', 'F']
        for node_key in node_keys:
            mock_graph.add_node(node_key, **mock_node_attrs)

        mock_graph.add_edges_from([('A', 'B'), ('B', 'C'), ('B', 'D'), ('D', 'E'), ('C', 'F')])

        mark_input_ports_lexicographically_based_on_input_node_key(mock_graph)
        return mock_graph

    # Description of a single quantizer: where it starts, where it is propagated to,
    # and whether/where it is merged afterwards.
    StateQuantizerTestStruct = namedtuple('StateQuantizerTestStruct',
                                          ('init_node_to_trait_and_configs_dict',
                                           'starting_quantizer_ip_node',
                                           'target_node_for_quantizer',
                                           'is_merged',
                                           'prop_path'))

    # A merging test case: the quantizers to set up and the quantizers expected in the end.
    SetQuantizersTestStruct = namedtuple('SetQuantizersTestStruct',
                                         ('start_set_quantizers',
                                          'expected_set_quantizers'))

    MERGE_QUANTIZER_INTO_PATH_TEST_CASES = [
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('C'),
                    is_merged=True,
                    prop_path=[(InsertionPointGraph.get_post_hook_node_key('B'),
                                InsertionPointGraph.get_pre_hook_node_key('C'))]
                )
            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'PRE HOOK B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                                       [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        ),
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=True,
                    prop_path=[('B', InsertionPointGraph.get_post_hook_node_key('B'))]
                )
            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        ),
        SetQuantizersTestStruct(
            start_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'E': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('E'),
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                ),
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'F': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()]),
                    },
                    starting_quantizer_ip_node=InsertionPointGraph.get_pre_hook_node_key('F'),
                    target_node_for_quantizer=InsertionPointGraph.get_pre_hook_node_key('C'),
                    is_merged=True,
                    prop_path=[(InsertionPointGraph.get_post_hook_node_key('B'),
                                InsertionPointGraph.get_pre_hook_node_key('C'))]
                )
            ],
            expected_set_quantizers=[
                StateQuantizerTestStruct(
                    init_node_to_trait_and_configs_dict=
                    {
                        'B': (QuantizationTrait.INPUTS_QUANTIZABLE,
                              [QuantizerConfig()])
                    },
                    starting_quantizer_ip_node=['E', 'F'],
                    target_node_for_quantizer=InsertionPointGraph.get_post_hook_node_key('B'),
                    is_merged=False,
                    prop_path=None
                )
            ]
        )
    ]

    @staticmethod
    @pytest.fixture(params=MERGE_QUANTIZER_INTO_PATH_TEST_CASES)
    def merge_quantizer_into_path_test_struct(request):
        """Parametrize over MERGE_QUANTIZER_INTO_PATH_TEST_CASES."""
        return request.param

    @pytest.fixture
    def model_graph_qpsg(self):
        """Return a QPSG built on the graph produced by `get_model_graph`."""
        mock_graph = self.get_model_graph()
        ip_graph = InsertionPointGraph(mock_graph)
        quant_prop_graph = QPSG(ip_graph)
        return quant_prop_graph

    def test_merge_quantizer_into_path(self, model_graph_qpsg, merge_quantizer_into_path_test_struct):
        """Set up and propagate the starting quantizers, merge the marked ones, then check the final graph."""
        quant_prop_graph = model_graph_qpsg

        # Collected over all the starting quantizers. It must live outside of the loop below,
        # otherwise every iteration would discard the merges recorded by the previous ones.
        merged_prop_quant = []

        for quantizers_test_struct in merge_quantizer_into_path_test_struct.start_set_quantizers:
            init_node_to_trait_and_configs_dict = quantizers_test_struct.init_node_to_trait_and_configs_dict
            starting_quantizer_ip_node = quantizers_test_struct.starting_quantizer_ip_node
            target_node = quantizers_test_struct.target_node_for_quantizer
            is_merged = quantizers_test_struct.is_merged
            prop_path = quantizers_test_struct.prop_path

            node_key_vs_trait_dict = dict.fromkeys(
                quant_prop_graph.nodes,
                QuantizationTrait.QUANTIZATION_AGNOSTIC)  # type: Dict[str, QuantizationTrait]
            primary_prop_quant = None
            for node_key, trait_and_configs_tuple in init_node_to_trait_and_configs_dict.items():
                trait = trait_and_configs_tuple[0]
                node_key_vs_trait_dict[node_key] = trait
            quant_prop_graph = self.mark_nodes_with_traits(quant_prop_graph, node_key_vs_trait_dict)

            for node_key, trait_and_configs_tuple in init_node_to_trait_and_configs_dict.items():
                trait = trait_and_configs_tuple[0]
                qconfigs = trait_and_configs_tuple[1]
                if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
                    ip_node_key = InsertionPointGraph.get_pre_hook_node_key(node_key)
                    prop_quant = quant_prop_graph.add_propagating_quantizer(qconfigs, ip_node_key)
                    if ip_node_key == starting_quantizer_ip_node:
                        primary_prop_quant = prop_quant

            path = get_edge_paths_for_propagation(quant_prop_graph, target_node, starting_quantizer_ip_node)
            primary_prop_quant = quant_prop_graph.propagate_quantizer_via_path(primary_prop_quant, path[0])
            if is_merged:
                merged_prop_quant.append((primary_prop_quant, prop_path))

            quant_prop_graph.run_consistency_check()

        for quantizer_to_merge, merge_path in merged_prop_quant:
            quant_prop_graph.merge_quantizer_into_path(quantizer_to_merge, merge_path)
            quant_prop_graph.run_consistency_check()

        expected_quantizers_test_struct = merge_quantizer_into_path_test_struct.expected_set_quantizers
        self.check_final_state_qpsg(quant_prop_graph, expected_quantizers_test_struct)

    @staticmethod
    def check_final_state_qpsg(final_quant_prop_graph, expected_quantizers_test_struct):
        """Assert that every expected quantizer sits in its target node and affects exactly the expected edges.

        :param final_quant_prop_graph: the graph to check.
        :param expected_quantizers_test_struct: iterable of StateQuantizerTestStruct; of each one,
            only `target_node_for_quantizer` and `starting_quantizer_ip_node` are read here.
        """
        for quantizer_param in expected_quantizers_test_struct:
            target_node = quantizer_param.target_node_for_quantizer

            # The expected edges are the union of the first propagation path of every start node.
            expected_prop_path = set()
            for start_node in quantizer_param.starting_quantizer_ip_node:
                added_path = get_edge_paths_for_propagation(final_quant_prop_graph, target_node, start_node)
                expected_prop_path.update(added_path[0])

            quantizer = final_quant_prop_graph.nodes[target_node][QPSG.PROPAGATING_QUANTIZER_NODE_ATTR]

            assert quantizer is not None

            for from_node_key, to_node_key in expected_prop_path:
                edge = final_quant_prop_graph.edges[(from_node_key, to_node_key)]
                assert quantizer in edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                from_node = final_quant_prop_graph.nodes[from_node_key]
                from_node_type = from_node[QPSG.NODE_TYPE_NODE_ATTR]
                if from_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                    assert quantizer in from_node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

            assert quantizer.affected_edges == expected_prop_path
if __name__ == '__main__':
    # Profile the reference and the custom symmetric quantizers on CUDA
    # for every entry of TEST_PARAMS_STRUCT.
    for input_name, input_size, gpu_runs in TEST_PARAMS_STRUCT:
        print("CUDA " + input_name)
        print("------------------------------------------------")
        # The very same line is printed under each of the headers below.
        input_size_line = "input size: {0}".format(input_size)

        print("Pytorch Symmetric (cuda 0) impl:")
        print(input_size_line)
        reference_quantizer = ReferenceQuantize(NBITS).cuda()
        run_profile(reference_quantizer, input_size, 'cuda', gpu_runs)

        print()
        print("Custom Symmetric (cuda 0 ) impl:")
        print(input_size_line)
        custom_quantizer = SymmetricQuantizer(QuantizerConfig(bits=NBITS)).cuda()
        run_profile(custom_quantizer, input_size, 'cuda', gpu_runs)

        print()
        print("Pytorch Symmetric Per Weight Channel (cuda 0) impl:")
        print(input_size_line)
        reference_per_channel_quantizer = ReferenceQuantize(NBITS,
                                                            input_shape=input_size,
                                                            per_channel=True,
                                                            is_weights=True).cuda()
        run_profile(reference_per_channel_quantizer, input_size, 'cuda', gpu_runs)

        print()
        print("Custom Symmetric Per Weight Channel (cuda 0 ) impl")
        print(input_size_line)
        # NOTE(review): no profiling call follows the last header in this excerpt.
def create_quantize_module(_, is_weights=False, input_shape=None):
    """Create a SymmetricQuantizer whose config depends on the `is_weights` flag only.

    :param _: ignored.
    :param is_weights: passed to QuantizerConfig both as `signedness_to_force` and as `is_weights`.
    :param input_shape: accepted, but not used.
    :return: the created SymmetricQuantizer.
    """
    quantizer_config = QuantizerConfig(signedness_to_force=is_weights,
                                       is_weights=is_weights)
    return SymmetricQuantizer(quantizer_config)
class QuantizerPropagationSolver:
    """Analyzes a fresh QuantizerPropagationStateGraph object according to HW configuration
    supplied in the initializer and produces the list of insertion commands that correspond to the
    final state of the quantizer propagation graph when the model has the most control flow graph edges
    quantized according to HW capabilities."""

    # Quantizer configs offered to INPUTS_QUANTIZABLE operators when no HW config is supplied.
    DEFAULT_QUANTIZATION_TYPES = [
        QuantizerConfig(bits=8,
                        mode=QuantizationMode.SYMMETRIC,
                        signedness_to_force=None,
                        per_channel=False)
    ]

    def __init__(self,
                 ignored_scopes=None,
                 hw_config=None,
                 debug_interface: 'QuantizationDebugInterface' = None,
                 propagation_strategy: PropagationStrategy = PropagationStrategy.AGGRESSIVE):
        """
        :param ignored_scopes: Passed as-is to every QuantizerPropagationStateGraph built by
            `run_on_ip_graph`.
        :param hw_config: Optional object providing `get_metatype_vs_quantizer_configs_map()`.
            When None, traits and configs are taken from DEFAULT_QUANT_TRAIT_TO_OP_DICT and
            DEFAULT_QUANTIZATION_TYPES.
        :param debug_interface: Optional object whose `visualize_quantizer_propagation` is called
            before every propagation step and once after the last one.
        :param propagation_strategy: With PropagationStrategy.CONSERVATIVE a quantizer is not
            allowed through a branching node if that would narrow the config lists of the
            quantizers on the branches (see `check_branching_transition`).
        """
        self._hw_config = hw_config
        self._debug_interface = debug_interface
        self._propagation_strategy = propagation_strategy  # TODO: determine from config
        self._operator_quantization_trait_map = self.get_operator_quantization_traits_map()
        self._operator_allowed_qconfigs_map = self._get_operator_qconfigs_map()
        self._active_propagating_quantizers_queue = deque()
        self._finished_propagating_quantizers = []  # type: List[PropagatingQuantizer]
        self._ignored_scopes = ignored_scopes
        # Filled by run_on_ip_graph: insertion point -> potential quantizer configs.
        self._potential_quantizers = {}

    def run_on_ip_graph(self, ip_graph: InsertionPointGraph) -> Dict[InsertionInfo, Optional[List[QuantizerConfig]]]:
        """ The main function to be used on an InsertionPointGraph to produce
        the list of insertion commands and configs corresponding to the final quantized
        graph state.

        :param ip_graph: The insertion point graph to solve quantizer placement for.
        :return: A dict mapping the InsertionInfo of each finished propagating quantizer to
            that quantizer's potential quantizer configs.
        """
        quant_prop_graph = QuantizerPropagationStateGraph(ip_graph, self._ignored_scopes)
        quant_prop_graph = self.set_allowed_quantization_types_for_operator_nodes(quant_prop_graph)
        quant_prop_graph = self.setup_initial_quantizers(quant_prop_graph)

        iteration_counter = 0
        # Quantizers are enqueued with appendleft() and taken with pop(), i.e. in FIFO order.
        while self._active_propagating_quantizers_queue:
            prop_quantizer = self._active_propagating_quantizers_queue.pop()
            if self._debug_interface is not None:
                self._debug_interface.visualize_quantizer_propagation(self, quant_prop_graph,
                                                                      str(iteration_counter))
            quant_prop_graph = self.propagation_step(prop_quantizer, quant_prop_graph)
            iteration_counter += 1

        if self._debug_interface is not None:
            self._debug_interface.visualize_quantizer_propagation(self, quant_prop_graph, "final")

        retval = {}
        self._potential_quantizers = {}
        for finished_prop_quantizer in self._finished_propagating_quantizers:
            final_node_key = finished_prop_quantizer.current_location_node_key
            final_node = quant_prop_graph.nodes[final_node_key]
            insertion_point = final_node[
                QuantizerPropagationStateGraph.INSERTION_POINT_DATA_NODE_ATTR]  # type: InsertionPoint
            ia_op_exec_context = insertion_point.ia_op_exec_context
            insertion_info = InsertionInfo(
                OperationExecutionContext(operator_name=ia_op_exec_context.operator_name,
                                          scope_in_model=ia_op_exec_context.scope_in_model,
                                          call_order=ia_op_exec_context.call_order,
                                          tensor_metas=[None])
            )  # TODO: fix this, rethink InsertionInfo here and elsewhere

            potential_quant_configs = finished_prop_quantizer.potential_quant_configs
            self._potential_quantizers[insertion_point] = potential_quant_configs
            retval[insertion_info] = potential_quant_configs
        return retval

    def propagation_step(self, curr_prop_quantizer: PropagatingQuantizer,
                         quant_prop_graph: QuantizerPropagationStateGraph) -> QuantizerPropagationStateGraph:
        """Performs a single propagation step for curr_prop_quantizer and returns the updated graph.

        The quantizer is cloned once per additional path leading to an immediately dominating
        insertion point. For each (path, quantizer) pair the quantizer is then either propagated
        along the path and re-queued, merged into the path, or backtracked to its accepting
        location and recorded as finished. A quantizer with no paths available is backtracked
        and recorded as finished right away.
        The location before and after the step should correspond to some insertion point."""
        # TODO: full-fledged discrete finite automata approach? Switch to traversing a graph
        # consisting of insertion points only, with reversed edges holding associated operator data?
        curr_node_key = curr_prop_quantizer.current_location_node_key
        curr_node = quant_prop_graph.nodes[curr_node_key]
        curr_node_type = curr_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
        assert curr_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT

        # Assumption: paths are at most 2 edges - either one edge to neighbouring insertion point
        # or one edge to operation and next edge to its own neighbouring insertion point.
        paths = quant_prop_graph.get_paths_to_immediately_dominating_insertion_points(curr_node_key)
        if not paths:
            prop_quantizer = quant_prop_graph.backtrack_propagation_until_accepting_location(curr_prop_quantizer)
            if prop_quantizer is not None:
                self._finished_propagating_quantizers.append(prop_quantizer)
            return quant_prop_graph

        surviving_prop_quantizers = []

        # The original quantizer takes the first path; every further path gets its own clone.
        prop_quantizers_to_process = [curr_prop_quantizer]
        for _ in range(1, len(paths)):
            additional_prop_quantizer = quant_prop_graph.clone_propagating_quantizer(curr_prop_quantizer)
            prop_quantizers_to_process.append(additional_prop_quantizer)

        for path, prop_quantizer in zip(paths, prop_quantizers_to_process):
            status = self.check_transition_via_path(prop_quantizer, path, quant_prop_graph)
            if status == TransitionStatus.SHOULD_NOT_TRANSITION:
                prop_quantizer = quant_prop_graph.backtrack_propagation_until_accepting_location(prop_quantizer)
                if prop_quantizer is not None:
                    self._finished_propagating_quantizers.append(prop_quantizer)
            elif status == TransitionStatus.SHOULD_TRANSITION:
                prop_quantizer = quant_prop_graph.propagate_quantizer_via_path(prop_quantizer, path)
                surviving_prop_quantizers.append(prop_quantizer)
            elif status == TransitionStatus.SHOULD_MERGE:
                # The surviving quantizer will have its "affected edges" set extended
                # by the corresponding set of the merged quantizer. The assumption
                # here is that the surviving quantizer should never have to cross
                # such a "merge point" while backtracking to an accepting location.
                quant_prop_graph.merge_quantizer_into_path(prop_quantizer, path)

        for prop_quantizer in surviving_prop_quantizers:
            self._active_propagating_quantizers_queue.appendleft(prop_quantizer)
        return quant_prop_graph

    def get_allowed_quantizer_configs_for_operator(self, quant_det_id: OperatorMetatype) -> List[QuantizerConfig]:
        """Returns the entry of the operator metatype -> quantizer configs map for quant_det_id."""
        return self._operator_allowed_qconfigs_map[quant_det_id]

    def set_allowed_quantization_types_for_operator_nodes(self, quant_prop_graph: QuantizerPropagationStateGraph):
        """Stores the quantization trait on every operator node of the graph and, for nodes with
        the INPUTS_QUANTIZABLE trait, also the allowed input quantizer configs.

        Operator nodes without a metatype produce a warning and are treated as quantization
        agnostic. Returns the same graph object."""
        for node_key, node in quant_prop_graph.nodes.items():
            node_type = node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
            if node_type != QuantizerPropagationStateGraphNodeType.OPERATOR:
                continue

            quant_det_id = node[QuantizerPropagationStateGraph.OPERATOR_METATYPE_NODE_ATTR]
            if quant_det_id is None:
                warnings.warn("Unknown metatype for operator node: {}".format(node_key))
                trait = QuantizationTrait.QUANTIZATION_AGNOSTIC
            else:
                trait = self._operator_quantization_trait_map[quant_det_id]
            node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR] = trait
            if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
                node[QuantizerPropagationStateGraph.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR] = \
                    self.get_allowed_quantizer_configs_for_operator(quant_det_id)
        return quant_prop_graph

    def get_operator_quantization_traits_map(self) -> Dict[OperatorMetatype, QuantizationTrait]:
        """Builds the operator metatype -> quantization trait map.

        Without a HW config every registered metatype is quantization agnostic unless
        DEFAULT_QUANT_TRAIT_TO_OP_DICT assigns it another trait. With a HW config the trait is
        derived from the metatype's config list: None means agnostic, a non-empty list means
        inputs quantizable and an empty list means non-quantizable."""
        # TODO: ensure that there are no name collisions between ops in different torch subpackages with the same name
        retval = {}
        if self._hw_config is None:
            for op_meta in OPERATOR_METATYPES.registry_dict.values():
                retval[op_meta] = QuantizationTrait.QUANTIZATION_AGNOSTIC  # Default value
            for trait, meta_list in DEFAULT_QUANT_TRAIT_TO_OP_DICT.items():
                for op_meta in meta_list:  # type: OperatorMetatype
                    retval[op_meta] = trait
        else:
            op_meta_vs_qconfs_map = self._hw_config.get_metatype_vs_quantizer_configs_map()
            for op_meta, qconf_list in op_meta_vs_qconfs_map.items():
                if qconf_list is None:
                    trait = QuantizationTrait.QUANTIZATION_AGNOSTIC
                elif qconf_list:
                    trait = QuantizationTrait.INPUTS_QUANTIZABLE
                else:
                    trait = QuantizationTrait.NON_QUANTIZABLE
                retval[op_meta] = trait
        return retval

    def _get_operator_qconfigs_map(self) -> Dict[OperatorMetatype, Optional[List[QuantizerConfig]]]:
        """Builds the operator metatype -> allowed quantizer configs map.

        With a HW config the map is taken from it unchanged. Otherwise metatypes listed under
        INPUTS_QUANTIZABLE in DEFAULT_QUANT_TRAIT_TO_OP_DICT get DEFAULT_QUANTIZATION_TYPES,
        metatypes listed under any other trait get an empty list, and all remaining registered
        metatypes get None."""
        # TODO: ensure that there are no name collisions between ops in different torch subpackages with the same name
        retval = {}
        if self._hw_config is None:
            for op_meta in OPERATOR_METATYPES.registry_dict.values():
                # Default value. None (rather than a QuantizationTrait member) keeps the values of
                # this map restricted to config lists and follows the convention used for HW config
                # maps in get_operator_quantization_traits_map, where None marks an agnostic operator.
                retval[op_meta] = None
            for trait, meta_list in DEFAULT_QUANT_TRAIT_TO_OP_DICT.items():
                if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
                    for op_meta in meta_list:  # type: OperatorMetatype
                        retval[op_meta] = self.DEFAULT_QUANTIZATION_TYPES
                else:
                    for op_meta in meta_list:  # type: OperatorMetatype
                        retval[op_meta] = []
        else:
            retval = self._hw_config.get_metatype_vs_quantizer_configs_map()
        return retval

    def debug_visualize(self, quant_prop_graph: QuantizerPropagationStateGraph, dump_path: str):
        """Writes the visualized propagation graph to dump_path in .dot format, labelled with the
        ids of the currently active and the finished propagating quantizers."""
        out_graph = quant_prop_graph.get_visualized_graph()
        active_ids_str = ", ".join(str(pq.id) for pq in self._active_propagating_quantizers_queue)
        finished_ids_str = ", ".join(str(pq.id) for pq in self._finished_propagating_quantizers)
        out_graph.graph['graph'] = {
            "label": "Propagating quantizers: {}\n"
                     "Finished quantizers: {}".format(active_ids_str, finished_ids_str),
            "labelloc": "t"
        }
        nx.drawing.nx_pydot.write_dot(out_graph, dump_path)

    def setup_initial_quantizers(self,
                                 quant_prop_graph: QuantizerPropagationStateGraph) -> QuantizerPropagationStateGraph:
        """Determines the initial subset of the nodes that must be quantized
        and corresponding allowed quantization configs (possibly multiple) for each
        quantizer.

        A propagating quantizer is added at the first predecessor of every operator node that is
        not ignored, has predecessors and is neither non-quantizable nor quantization agnostic;
        each added quantizer is put into the active queue. Returns the same graph object."""
        skipped_traits = [QuantizationTrait.NON_QUANTIZABLE, QuantizationTrait.QUANTIZATION_AGNOSTIC]
        for node_key, node in quant_prop_graph.nodes.items():
            node_type = node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
            if node_type != QuantizerPropagationStateGraphNodeType.OPERATOR:
                continue
            if node_key in quant_prop_graph.ignored_node_keys:
                continue

            preds = list(quant_prop_graph.predecessors(node_key))
            if not preds:
                continue  # TODO: remove this once module insertion points are included in the IP graph

            # Should be immediately preceded by an insertion point.
            pred_ip_key = preds[0]
            pred_node = quant_prop_graph.nodes[pred_ip_key]
            pred_node_type = pred_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
            assert pred_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT, \
                "Invalid insertion point graph supplied for quantizer propagation!"

            if node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR] in skipped_traits:
                continue

            quant_det_id = node[QuantizerPropagationStateGraph.OPERATOR_METATYPE_NODE_ATTR]
            qconf_list = self.get_allowed_quantizer_configs_for_operator(quant_det_id)
            prop_quantizer = quant_prop_graph.add_propagating_quantizer(qconf_list, pred_ip_key)
            self._active_propagating_quantizers_queue.appendleft(prop_quantizer)

        return quant_prop_graph

    def check_branching_transition(self, quant_prop_graph: QuantizerPropagationStateGraph,
                                   prop_quantizer: PropagatingQuantizer,
                                   branching_node_key: str) -> Optional[TransitionStatus]:
        """If a propagating quantizer advances through a node that branches
        downwards, the branches neighbouring to the one that the propagating quantizer
        had just propagated from will have the precision of the quantizer imposed upon
        them. This is not always desirable - we might want to keep some branches in
        higher precision than the others. For this reason, this function checks whether
        the quantizer may safely advance through a branching node based on the possible
        configs of the quantizers it might affect by doing so.

        :return: TransitionStatus.SHOULD_NOT_TRANSITION if the quantizer must stop at the
            branching node, None if the branching node does not prevent the transition."""
        dom_op_node_keys = quant_prop_graph.get_quantizable_op_nodes_immediately_dominated_by_node(
            branching_node_key)

        master_possible_qconfigs = prop_quantizer.potential_quant_configs
        slave_possible_qconfigs_dict = {}
        for op_node_key in dom_op_node_keys:
            op_node = quant_prop_graph.nodes[op_node_key]
            affecting_prop_quantizers = op_node[
                QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            if not affecting_prop_quantizers:
                # The branch op is forced to be FP32 - should not proceed through the branch node.
                return TransitionStatus.SHOULD_NOT_TRANSITION
            # Only the first affecting quantizer is consulted.
            slave_possible_qconfigs_dict[op_node_key] = affecting_prop_quantizers[0].potential_quant_configs

        master_merged_qconfigs, slave_merged_qconfigs_dict = \
            self.get_merged_qconfigs(master_possible_qconfigs, slave_possible_qconfigs_dict)
        if not master_merged_qconfigs:
            # This quantizer's precision does not encompass the precisions of quantizers
            # propagating through downward branches.
            return TransitionStatus.SHOULD_NOT_TRANSITION

        if self._propagation_strategy == PropagationStrategy.CONSERVATIVE:
            # Refuse the transition if merging would drop any config from any branch.
            for op_node_key, slave_merged_qconfigs_list in slave_merged_qconfigs_dict.items():
                if len(slave_possible_qconfigs_dict[op_node_key]) != len(slave_merged_qconfigs_list):
                    return TransitionStatus.SHOULD_NOT_TRANSITION

        return None

    def check_transition_via_path(self, prop_quantizer: PropagatingQuantizer, path: List,
                                  quant_prop_graph: QuantizerPropagationStateGraph) -> TransitionStatus:
        """Determines which action should be taken regarding the
        prop_quantizer's propagation via path, which may be one of many possible
        propagation paths.

        :param path: A sequence of (from_node_key, to_node_key) edges of quant_prop_graph."""
        for from_node_key, to_node_key in path:
            from_node = quant_prop_graph.nodes[from_node_key]

            if len(list(quant_prop_graph.successors(from_node_key))) > 1:
                # If a quantizer simply passes up through a downward-branching node, it may spoil the
                # precision for operations on neighbouring branches. Consider a 4-bit quantizer rising
                # through a branch node and an 8-bit quantizer arriving at the same node later. Therefore,
                # prior to allowing the quantizer to pass through a branching node we need to ensure that
                # the precision of the quantizer is a superset of precisions of the first non-quantization agnostic
                # operations on each branch.
                status = self.check_branching_transition(quant_prop_graph, prop_quantizer, from_node_key)
                if status is not None:
                    return status

            from_node_type = from_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
            if from_node_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
                trait = from_node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR]
                if trait in [QuantizationTrait.NON_QUANTIZABLE, QuantizationTrait.INPUTS_QUANTIZABLE]:
                    return TransitionStatus.SHOULD_NOT_TRANSITION

            edge = quant_prop_graph.edges[from_node_key, to_node_key]
            potential_quantizers = edge[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
            if potential_quantizers:
                # Assuming that multiple affecting quantizers should all have the same quantization config
                # by construction
                if prop_quantizer.potential_quant_configs == potential_quantizers[0].potential_quant_configs:
                    return TransitionStatus.SHOULD_MERGE
                return TransitionStatus.SHOULD_NOT_TRANSITION

            if from_node_type == QuantizerPropagationStateGraphNodeType.INSERTION_POINT:
                potential_quantizers = from_node[
                    QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
                if potential_quantizers:
                    # Affecting quantizers should have the same configs by construction, so we only
                    # check the first
                    if prop_quantizer.potential_quant_configs == potential_quantizers[0].potential_quant_configs:
                        return TransitionStatus.SHOULD_MERGE
        return TransitionStatus.SHOULD_TRANSITION

    def get_merged_qconfigs(self, master_potential_qconfigs_list: List[QuantizerConfig],
                            slave_potential_qconfigs_dict: Dict[str, List[QuantizerConfig]]
                            ) -> Tuple[List[QuantizerConfig], Dict[str, List[QuantizerConfig]]]:
        """Returns potential qconfigs lists for 'master' and 'slave' quantizers
        that are compatible with each other. Compatibility is decided in terms of
        master quantizer having configs which all have higher precision than all the
        slave potential quantizer configs.

        The inputs are not modified. Returns ([], {}) when no master config can be kept."""
        final_master_merged_qconfigs_list = deepcopy(master_potential_qconfigs_list)
        curr_slave_merged_qconfigs_dict = deepcopy(slave_potential_qconfigs_dict)
        # TODO: implement variant solutions, i.e. for each set of resultant merged
        # master potential qconfig lists we have, in general, different merged slave potential
        # config lists. Currently greedy approach is used.
        for m_qconfig in master_potential_qconfigs_list:
            candidate_slave_merged_qconfigs_dict = deepcopy(curr_slave_merged_qconfigs_dict)
            for node_key, s_qconfig_list in curr_slave_merged_qconfigs_dict.items():
                candidate_list = candidate_slave_merged_qconfigs_dict[node_key]
                for s_qconfig in s_qconfig_list:
                    if m_qconfig < s_qconfig and s_qconfig in candidate_list:
                        candidate_list.remove(s_qconfig)

            if all(candidate_slave_merged_qconfigs_dict.values()):
                # Every branch still has at least one config left - accept the narrowed lists.
                curr_slave_merged_qconfigs_dict = candidate_slave_merged_qconfigs_dict
            else:
                # No options left for slave configs on one of the branches to accommodate the master
                # config - this master config cannot be used to be merged into.
                final_master_merged_qconfigs_list.remove(m_qconfig)

        if not final_master_merged_qconfigs_list:
            return [], {}
        return final_master_merged_qconfigs_list, curr_slave_merged_qconfigs_dict

    def get_finished_propagating_quantizers(self):
        """Returns the list of propagating quantizers recorded as finished so far."""
        return self._finished_propagating_quantizers

    def get_active_propagating_quantizers_queue(self):
        """Returns the deque of propagating quantizers still awaiting a propagation step."""
        return self._active_propagating_quantizers_queue