def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize a float model with post training static quantization.

    The steps are, in order:
    1. deep-copy ``model`` unless ``inplace`` is set,
    2. switch it to eval mode,
    3. prepare it for calibration,
    4. calibrate it by calling ``run_fn(model, *run_args)``,
    5. convert it to a quantized model using ``mapping``.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model;
            invoked as ``run_fn(model, *run_args)``
        run_args: positional arguments for `run_fn`
        mapping: correspondence between original module types and quantized
            counterparts; ``None`` selects the default static quant module
            mappings
        inplace: carry out model transformations in-place, the original
            module is mutated; otherwise a deep copy is quantized and the
            original is left untouched

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    module_map = (
        get_default_static_quant_module_mappings() if mapping is None else mapping
    )
    target = model if inplace else copy.deepcopy(model)
    target.eval()
    prepare(target, inplace=True)
    # Calibration: the caller-supplied function drives data through the
    # prepared model before conversion.
    run_fn(target, *run_args)
    convert(target, module_map, inplace=True)
    return target
def add_d2_quant_mapping(mappings):
    """HACK: Add d2 specific module mapping for eager model quantization.

    Every ``(k, v)`` pair in ``mappings`` is registered into both
    ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS`` and
    ``DEFAULT_QAT_MODULE_MAPPINGS`` of
    ``torch.quantization.quantization_mappings``. The same ``v`` is used for
    both tables.

    A key already present in the corresponding default mapping is left
    untouched. Existing entries are never overwritten; only new keys are
    added.

    Note: this mutates module-level tables, so the change is global to the
    process.

    Args:
        mappings: dict from source module class to the module class it should
            be swapped to.
    """
    import torch.quantization.quantization_mappings as qm

    # Fetch the current defaults once rather than twice per key. Keys of
    # ``mappings`` are unique, so inserting one key never changes the
    # membership test for another: the result matches a per-key lookup.
    existing_static = qm.get_default_static_quant_module_mappings()
    existing_qat = qm.get_default_qat_module_mappings()
    for k, v in mappings.items():
        if k not in existing_static:
            qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[k] = v
        if k not in existing_qat:
            qm.DEFAULT_QAT_MODULE_MAPPINGS[k] = v
def _convert(module, mapping=None, inplace=False, convert_custom_config_dict=None):
    r"""Recursively swap the submodules of ``module`` according to ``mapping``.

    Each direct child is replaced by the result of
    ``swap_module(child, mapping, custom_module_class_mapping)``. Per this
    function's contract, that swap calls the ``from_float`` method on the
    target module class.

    Before a child is swapped, its own subtree is converted (in place).
    The exceptions are fused modules and observed custom modules, which
    are swapped as a single unit without descending into them.

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
            module type, can be overwritten to allow swapping user defined
            Modules; ``None`` selects the default static quant module
            mappings
        inplace: carry out model transformations in-place, the original
            module is mutated; otherwise a deep copy is converted
        convert_custom_config_dict: optional dict. Only its
            ``"observed_to_quantized_custom_module_class"`` entry is read
            here. It is forwarded to ``swap_module`` and matched against each
            child's exact type.

    Return:
        The converted module (the same object as ``module`` when
        ``inplace`` is True).
    """
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    config = {} if convert_custom_config_dict is None else convert_custom_config_dict
    custom_module_class_mapping = config.get(
        "observed_to_quantized_custom_module_class", {})

    target = module if inplace else copy.deepcopy(module)

    replacements = {}
    for child_name, child in target.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit, so their internals are not recursed into
        swap_as_unit = (isinstance(child, _FusedModule)
                        or type(child) in custom_module_class_mapping)
        if not swap_as_unit:
            # `target` is already ours (copied or in-place), so the
            # subtree can be converted in place.
            _convert(child, mapping, inplace=True,
                     convert_custom_config_dict=config)
        replacements[child_name] = swap_module(
            child, mapping, custom_module_class_mapping)

    # Collect replacements first, then write them back once traversal is done.
    for child_name, replacement in replacements.items():
        target._modules[child_name] = replacement
    return target
def test_op_relationship_mapping(self):
    """
    Tests that the mapping of op relationships is complete.

    Checks, in order:
    1. every static quant module mapping (except the quant/dequant stubs and
       FloatFunctional) is a related (fp32, int8) pair,
    2. every float-to-quantized operator mapping is a related pair,
    3. every dynamic quant module mapping (except the RNN variants listed
       below) is a related pair,
    4. for each default quant pattern, its base op is validated according to
       the QuantizeHandler class the pattern maps to.
    """
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    def _assert_pairs_related(fp32_to_int8, types_to_skip=()):
        # Every (fp32, int8) pair whose fp32 type is not skipped must appear
        # in the relationship mapping.
        for fp32_type, int8_type in fp32_to_int8.items():
            if fp32_type in types_to_skip:
                continue
            in_type_a_related_to_b = \
                (fp32_type, int8_type) in type_a_related_to_b
            self.assertTrue(
                in_type_a_related_to_b,
                f"{fp32_type} and {int8_type} need a relationship mapping")

    # 1. check static quant module mappings
    # skip quants and dequants, for the purposes of Numerical Suite
    _assert_pairs_related(
        get_default_static_quant_module_mappings(),
        types_to_skip=(
            torch.quantization.QuantStub,
            torch.quantization.DeQuantStub,
            nnq.FloatFunctional,
        ))

    # 2. check static quant op mappings
    _assert_pairs_related(get_default_float_to_quantized_operator_mappings())

    # 3. check dynamic quant mappings
    # TODO(future PR): enable correct weight extraction for these
    # and remove from this list.
    _assert_pairs_related(
        get_default_dynamic_quant_module_mappings(),
        types_to_skip=(
            nn.GRUCell,
            nn.GRU,
            nn.LSTMCell,
            nn.RNNCell,
        ))

    # 4. go through the ops mapped to each QuantizeHandler type, and verify
    # correctness.
    def _op_in_base_sets_of_related_ops(op):
        return any(
            op in ops for ops in base_name_to_sets_of_related_ops.values())

    # Handler classes whose ops are all expected to be in a related-ops set.
    qhandler_cls_all_ops_quantizeable = [
        qp.CatQuantizeHandler,
        qp.ConvReluQuantizeHandler,
        qp.LinearReLUQuantizeHandler,
        qp.BatchNormQuantizeHandler,
        qp.EmbeddingQuantizeHandler,
        qp.RNNDynamicQuantizeHandler,
        qp.ELUQuantizeHandler,
    ]
    # Handler classes whose ops use the same op signature for fp32 and
    # quantized tensors; nothing to check for them.
    qhandler_cls_quant_op_same_signature = [
        qp.FixedQParamsOpQuantizeHandler,
        qp.CopyNodeQuantizeHandler,
    ]
    # BinaryOpQuantizeHandler ops that do not have quantized equivalents.
    binary_ops_to_skip = [
        torch.bmm, torch.sum, torch.div, torch.sub,
        operator.truediv, operator.sub,
    ]
    # DefaultNodeQuantizeHandler ops excluded from the check.
    default_node_ops_to_skip = [
        torch.nn.SiLU,
        torch.nn.functional.silu,
    ]

    default_quant_patterns = get_default_quant_patterns()
    for pattern, qhandler_cls in default_quant_patterns.items():
        if isinstance(pattern, tuple):
            # for fused patterns, the last element is the base op
            base_op = pattern[-1]
        elif isinstance(pattern, str):
            # TODO(future PR): add handling for these
            continue
        else:
            base_op = pattern

        if qhandler_cls == qp.BinaryOpQuantizeHandler:
            if base_op in binary_ops_to_skip:
                continue
            self.assertTrue(_op_in_base_sets_of_related_ops(base_op),
                            f"{base_op} not in sets of related ops")
        elif qhandler_cls == qp.RNNDynamicQuantizeHandler:
            # TODO(future PR): add support for all classes in
            # RNNDynamicQuantizeHandler
            pass
        elif qhandler_cls == qp.DefaultNodeQuantizeHandler:
            if base_op in default_node_ops_to_skip:
                continue
            self.assertTrue(_op_in_base_sets_of_related_ops(base_op),
                            f"{base_op} not in sets of related ops")
        elif qhandler_cls in qhandler_cls_quant_op_same_signature:
            # these ops use the same op signature for fp32 and quantized
            # tensors
            pass
        elif qhandler_cls in qhandler_cls_all_ops_quantizeable:
            self.assertTrue(_op_in_base_sets_of_related_ops(base_op),
                            f"{base_op} not in sets of related ops")
        else:
            raise AssertionError(
                f"handling for {qhandler_cls} not implemented")