def add_auto_convert(module: torch.nn.Module) -> torch.nn.Module:
    def convert_to_dispatch_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(
                QuantizationConvertTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    module_id_to_fqn: Dict[int, str] = {}

    class QuantizationConvertTensorProxy(torch.Tensor):
        """
        An override of `torch.Tensor` to enable dynamic dispatch for
        quantization inference. For each function with a `__torch_function__`
        override, this proxy does the following for functions which need
        quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls `_auto_quant_state.op_convert_before_hook`
        3. executes the function, with target, args and kwargs possibly
           modified by (2)
        4. calls `_auto_quant_state.op_convert_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        Otherwise, calls the original function.
        """

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # to prevent printing things from going into an infinite loop
            if func == torch.Tensor.__repr__:
                return super().__torch_function__(func, types, args, kwargs)

            kwargs = kwargs if kwargs else {}
            # if we are in a function, the current module is always a parent
            parent_module = cur_module
            hook_type = get_torch_function_hook_type(parent_module, func)

            if enable_logging:
                with torch._C.DisableTorchFunction():
                    logger.debug(
                        f"__torch_function__ {func} " +
                        f"hook_type {hook_type} " +
                        # f"arg_types {[type(arg) for arg in args]}) " +
                        f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]}")

            if hook_type is HookType.OP_HOOKS:
                assert parent_module is not None
                qstate = parent_module._auto_quant_state

                # before hooks
                qstate.validate_cur_op(func)
                func, args, kwargs = qstate.op_convert_before_hook(
                    func, args, kwargs, parent_module)

                # forward
                output = super().__torch_function__(func, types, args, kwargs)

                # after hooks
                output = qstate.op_convert_after_hook(func, output)
                qstate.mark_cur_op_complete(func)

            elif hook_type is HookType.ARG_DEQUANTS:
                # disabling torch function to prevent infinite recursion on
                # getset
                # TODO(future PR): handle more dtypes
                with torch._C.DisableTorchFunction():
                    new_args = []
                    for arg in args:
                        if isinstance(arg, torch.Tensor) and arg.is_quantized:
                            new_args.append(arg.dequantize())
                        else:
                            new_args.append(arg)
                    args = tuple(new_args)
                output = super().__torch_function__(func, types, args, kwargs)

            else:  # HookType.NONE
                output = super().__torch_function__(func, types, args, kwargs)

            # TODO: is this right? Don't really understand this
            if output is NotImplemented:
                with torch._C.DisableTorchFunction():
                    output = func(*args, **kwargs).as_subclass(
                        QuantizationConvertTensorProxy)
                assert output is not NotImplemented

            if enable_logging:
                out_dtype = None
                if isinstance(output, torch.Tensor):
                    out_dtype = output.dtype
                logger.debug(f"__torch_function__ {func} out {out_dtype} end")

            return output

        def __repr__(self):
            return f'QuantizationConvertTensorProxy({super().__repr__()})'

    cur_module = None
    module_stack: List[torch.nn.Module] = []

    assert len(module.__class__.__bases__) == 1

    class QuantizationDispatchModule(
            module.__class__.__bases__[0]):  # type: ignore[name-defined]
        """
        An override of user defined subclass of `nn.Module` to enable
        dynamic tracing for quantization, after model conversion
        to quantized domain.

        `cur_module` keeps track of the current module in the stack.

        Tensor arguments are converted to `QuantizationConvertTensorProxy`.

        We override the `__call__` function to do the following for each
        module:

        If the module is an op which needs quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls parent module's `._auto_quant_state.op_convert_before_hook`
        3. executes the original module forward
        4. calls parent module's `_auto_quant_state.op_convert_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        If the module can contain children ops that need quantization:

        1. calls `_auto_quant_state.inputs_convert_hook` (not implemented yet)
        2. executes the original module forward
        3. calls `_auto_quant_state.outputs_convert_hook`

        Otherwise, calls the original module forward.
        """

        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_dispatch_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                if enable_logging:
                    fqn = module_id_to_fqn.get(id(self), None)
                    logger.debug(f"\nstarting fqn {fqn}")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if enable_logging:
                        logger.debug(
                            f"_patched_module_call {type(self)} " +
                            # f"arg_types {[type(arg) for arg in args]} " +
                            f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " +
                            f"hook_type {hook_type}")

                    if hook_type is HookType.OP_HOOKS:
                        # before hooks
                        assert parent_module is not None
                        assert isinstance(parent_module._auto_quant_state,
                                          AutoQuantizationState)
                        qstate = parent_module._auto_quant_state
                        if enable_logging:
                            logger.debug(qstate)
                        qstate.validate_cur_op(cur_module)
                        _, args, kwargs = qstate.op_convert_before_hook(
                            cur_module, args, kwargs, cur_module)

                        # forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        output = qstate.op_convert_after_hook(
                            cur_module, output)
                        qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        cur_qstate = cur_module._auto_quant_state
                        if enable_logging:
                            logger.debug(cur_qstate)
                        cur_qstate.validate_is_at_first_idx()

                        # before hooks (TODO)

                        # forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        assert isinstance(cur_qstate, AutoQuantizationState)
                        output = cur_qstate.outputs_convert_hook(output)
                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        # disabling torch function to prevent infinite
                        # recursion on getset
                        # TODO(future PR): handle more dtypes
                        with torch._C.DisableTorchFunction():
                            new_args = []
                            for arg in args:
                                if isinstance(
                                        arg, torch.Tensor) and arg.is_quantized:
                                    dequant = arg.dequantize().as_subclass(
                                        QuantizationConvertTensorProxy
                                    )  # type: ignore[arg-type]
                                    new_args.append(dequant)
                                else:
                                    new_args.append(arg)
                            args = tuple(new_args)
                        output = orig_module_call(self, *args, **kwargs)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        logger.debug(
                            f"_patched_module_call {type(self)} " +
                            # f"out {type(output)} " +
                            f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " +
                            "end")
                        logger.debug(f"ending fqn {fqn}\n")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            try:
                for k, v in self.named_modules():
                    module_id_to_fqn[id(v)] = k
                    if hasattr(v, '_auto_quant_state'):
                        v._auto_quant_state.reset_to_new_call()

                needs_io_hooks = hasattr(self, '_auto_quant_state')

                # handle module input dtype conversions
                # TODO(implement)

                output = super().__call__(*new_args, **new_kwargs)

                # handle module output dtype conversions
                if needs_io_hooks:
                    qstate = self._auto_quant_state
                    assert isinstance(qstate, AutoQuantizationState)
                    output = qstate.outputs_convert_hook(output)

                def unwrap_proxy(a):
                    if isinstance(a, QuantizationConvertTensorProxy):
                        a.__class__ = torch.Tensor  # type: ignore[assignment]
                    return a

                output = map_aggregate(output, unwrap_proxy)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]

        def rewrite_for_scripting(self):
            return auto_trace_rewriter.rewrite_for_scripting(self)

    pack_weights_for_functionals(module)
    module.__class__ = QuantizationDispatchModule
    return module
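
# A minimal, self-contained sketch (not part of the original source) of the
# `__torch_function__` interception pattern that
# `QuantizationConvertTensorProxy` relies on: `Tensor.as_subclass` rewraps a
# tensor without copying, every subsequent torch op on the wrapper routes
# through `__torch_function__`, and resetting `__class__` unwraps it again,
# exactly as `unwrap_proxy` does above.
def _torch_function_interception_demo():
    import torch

    class LoggingTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs if kwargs else {}
            # guard against infinite recursion when printing, mirroring the
            # `torch.Tensor.__repr__` check in the proxy above
            if func != torch.Tensor.__repr__:
                print(f"intercepted {getattr(func, '__name__', func)}")
            return super().__torch_function__(func, types, args, kwargs)

    x = torch.randn(2, 2).as_subclass(LoggingTensor)
    y = torch.add(x, x)         # prints "intercepted add"
    y.__class__ = torch.Tensor  # unwrap; same storage, plain tensor again
    return y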
def add_auto_observation(
    model: torch.nn.Module,
    example_inputs: Tuple[Any],
    input_dtypes: Any = (torch.float,),  # must be same structure as model inputs
    output_dtypes: Any = (torch.float,),  # must be same structure as model outputs
) -> torch.nn.Module:
    def convert_to_interception_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(
                QuantizationPrepareTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    cur_module = None
    first_call = True
    module_stack: List[torch.nn.Module] = []
    # Counter for tensor IDs, will be modified inplace by quant state.
    # This is used to track tensors from output ops to input ops. For example,
    # if op_n had a tensor output with id=1, and op_n+2 had a tensor input with
    # id=1, we know that the output of op_n is the input to op_n+2. Note,
    # this is a list because it needs to be incremented inplace.
    qtensor_id = [0]
    module_id_to_fqn: Dict[int, str] = {}

    class QuantizationPrepareTensorProxy(torch.Tensor):
        """
        An override of `torch.Tensor` to enable dynamic tracing for
        quantization.

        For each function with a `__torch_function__` override, this proxy
        does the following for functions which need quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls `_auto_quant_state.op_prepare_before_hook`
        3. executes the original function
        4. calls `_auto_quant_state.op_prepare_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        Otherwise, calls the original function.
        """

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # to prevent printing things from going into an infinite loop
            if func == torch.Tensor.__repr__:
                return super().__torch_function__(func, types, args, kwargs)

            if enable_logging:
                logger.debug(
                    f'__torch_function__ {str(func)} len_args {len(args)}')

            nonlocal qtensor_id
            nonlocal cur_module
            kwargs = kwargs if kwargs else {}
            # if we are in a function, the current module is always a parent
            parent_module = cur_module
            hook_type = get_torch_function_hook_type(parent_module, func)

            if hook_type is HookType.OP_HOOKS:
                assert parent_module is not None
                qstate = parent_module._auto_quant_state
                fqn = module_id_to_fqn[id(
                    parent_module)] if parent_module else None

                if not first_call:
                    qstate.validate_cur_op(func)

                # run "before" hook
                args, kwargs = qstate.op_prepare_before_hook(
                    func, args, kwargs, first_call, qtensor_id, fqn,
                    parent_module)

                # forward
                output = super().__torch_function__(func, types, args, kwargs)

                # run "after" hook
                output = qstate.op_prepare_after_hook(
                    func, output, args, first_call, qtensor_id, parent_module)
                qstate.mark_cur_op_complete(func)
            else:
                output = super().__torch_function__(func, types, args, kwargs)

            # TODO: is this right? Don't really understand this
            if output is NotImplemented:
                with torch._C.DisableTorchFunction():
                    output = func(*args, **kwargs).as_subclass(
                        QuantizationPrepareTensorProxy)
                assert output is not NotImplemented

            return output

        def __repr__(self):
            return f'QuantizationPrepareTensorProxy({super().__repr__()})'

    # TODO(future PR): add other math overrides

    class QuantizationInterceptionModule(type(model)):  # type: ignore[misc]
        """
        An override of user defined subclass of `nn.Module` to enable
        dynamic tracing for quantization.

        `cur_module` keeps track of the current module in the stack.

        During the first call, an `AutoQuantizationState` object is created
        and attached to each non-leaf module which we need to check for
        quantizeable operations.

        We override the `__call__` function to do the following for each
        module:

        If the module is an op which needs quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls parent module's `._auto_quant_state.op_prepare_before_hook`
        3. executes the original module forward
        4. calls parent module's `_auto_quant_state.op_prepare_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        If the module can contain children ops that need quantization:

        1. calls `_auto_quant_state.inputs_prepare_hook` (not implemented yet)
        2. executes the original module forward
        3. calls `_auto_quant_state.outputs_prepare_hook`

        Otherwise, calls the original module forward.
        """

        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_interception_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                if enable_logging:
                    logger.debug(f"_patched_module_call: {type(self)}")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    fqn = module_id_to_fqn.get(id(self), None)
                    if enable_logging:
                        logger.debug(f"\nstarting fqn {fqn}")
                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if hook_type is HookType.OP_HOOKS:
                        assert parent_module is not None
                        parent_qstate = parent_module._auto_quant_state
                        assert isinstance(parent_qstate, AutoQuantizationState)
                        # before hooks
                        if not first_call:
                            parent_qstate.validate_cur_op(cur_module)
                        args, kwargs = parent_qstate.op_prepare_before_hook(
                            cur_module, args, kwargs, first_call, qtensor_id,
                            fqn, cur_module)

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        # TODO is it correct to call_cur_module twice here?
                        output = parent_qstate.op_prepare_after_hook(
                            cur_module, output, args, first_call, qtensor_id,
                            cur_module)
                        parent_qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        # TODO(future PR): add inputs io hook
                        cur_qstate = cur_module._auto_quant_state
                        cur_qstate.validate_is_at_first_idx()

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        assert isinstance(cur_qstate, AutoQuantizationState)
                        output = cur_qstate.outputs_prepare_hook(
                            output, first_call, qtensor_id)
                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        output = orig_module_call(self, *args, **kwargs)
                        # if this fp32 was inplace, make sure to set the
                        # output dtype back to torch.float
                        if hasattr(output, '_qtensor_info'):
                            del output._qtensor_info

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        logger.debug(f"\nending fqn {fqn}")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

            nonlocal first_call
            try:
                # Create a list before iterating because we are adding new
                # named modules inside the loop.
                named_modules = list(self.named_modules())
                for k, v in named_modules:
                    # k is the global FQN, i.e. 'foo.bar.baz'
                    # v is the module instance
                    #
                    # we need to associate the global FQN with SeenOp
                    # for modules, this is the module FQN
                    # for functions, this is the parent module FQN
                    module_id_to_fqn[id(v)] = k

                    has_qconfig = hasattr(v, 'qconfig') and \
                        v.qconfig is not None
                    if has_qconfig and not is_leaf(v):
                        if first_call:
                            if v is self:
                                # for the top level module only, specify input
                                # and output dtypes
                                v._auto_quant_state = AutoQuantizationState(
                                    v.qconfig, input_dtypes, output_dtypes)
                            else:
                                v._auto_quant_state = AutoQuantizationState(
                                    v.qconfig)
                        else:
                            if not isinstance(v, AutoQuantizationState):
                                assert hasattr(v, '_auto_quant_state')
                                v._auto_quant_state.reset_to_new_call()

                output = super().__call__(*new_args, **new_kwargs)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
                first_call = False

    model.__class__ = QuantizationInterceptionModule
    # create the graph
    trace_with_inputs(model, example_inputs)
    return model
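
# A minimal end-to-end usage sketch (not from the original source) for the
# prepare/convert pair above. `_ToyModel` is a hypothetical model invented
# here; qconfig propagation to child modules is assumed to follow eager-mode
# PTQ conventions via `torch.ao.quantization.propagate_qconfig_`.
def _example_prepare_convert():
    import torch

    class _ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    m = _ToyModel().eval()
    m.qconfig = torch.ao.quantization.default_qconfig
    torch.ao.quantization.propagate_qconfig_(m)
    example_inputs = (torch.randn(1, 4),)
    m = add_auto_observation(m, example_inputs)  # trace and attach observers
    m(*example_inputs)                           # calibration pass
    m = add_auto_convert(m)                      # dispatch to quantized kernels
    return m(*example_inputs)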
def add_auto_convert(module: torch.nn.Module) -> torch.nn.Module:
    def convert_to_dispatch_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(
                QuantizationConvertTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    module_id_to_fqn: Dict[int, str] = {}

    # Counter for global quantizeable ops, useful for intermediate activation
    # logging.
    global_op_idx = [0]

    global_disable_torch_function_override = False

    class QuantizationConvertTensorProxy(torch.Tensor):
        """
        An override of `torch.Tensor` to enable dynamic dispatch for
        quantization inference. For each function with a `__torch_function__`
        override, this proxy does the following for functions which need
        quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls `_auto_quant_state.op_convert_before_hook`
        3. executes the function, with target, args and kwargs possibly
           modified by (2)
        4. calls `_auto_quant_state.op_convert_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        Otherwise, calls the original function.
        """

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            nonlocal global_disable_torch_function_override
            if (
                    # global override means disable the override here
                    global_disable_torch_function_override or
                    # to prevent printing things from going into an infinite
                    # loop
                    func == torch.Tensor.__repr__ or
                    # we don't need to override getters in this framework
                    func.__name__ == '__get__'
            ):
                return super().__torch_function__(func, types, args, kwargs)

            kwargs = kwargs if kwargs else {}
            # if we are in a function, the current module is always a parent
            parent_module = cur_module
            hook_type = get_torch_function_hook_type(parent_module, func)

            if enable_logging:
                fqn_for_logging = module_id_to_fqn.get(
                    id(parent_module), 'unknown') if parent_module else None
                logger.debug(
                    f" fqn:{fqn_for_logging} _tf_ {func} " +
                    f"hook_type {hook_type} " +
                    # f"arg_types {[type(arg) for arg in args]}) " +
                    f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]}")

            if hook_type is HookType.OP_HOOKS:
                qstate: AutoQuantizationState = \
                    parent_module._auto_quant_state  # type: ignore[union-attr]

                # before hooks
                qstate.validate_cur_op(func)
                func, args, kwargs = qstate.op_convert_before_hook(
                    func, args, kwargs, parent_module)  # type: ignore[arg-type]

                # forward
                output = super().__torch_function__(func, types, args, kwargs)

                # after hooks
                output = qstate.op_convert_after_hook(
                    func, output, global_op_idx)
                qstate.mark_cur_op_complete(func)

            elif hook_type is HookType.ARG_DEQUANTS:
                # TODO(future PR): handle more dtypes
                new_args = []
                for arg in args:
                    if isinstance(arg, torch.Tensor) and arg.is_quantized:
                        new_args.append(arg.dequantize())
                    else:
                        new_args.append(arg)
                args = tuple(new_args)
                output = super().__torch_function__(func, types, args, kwargs)

            else:  # HookType.NONE
                output = super().__torch_function__(func, types, args, kwargs)

            # TODO: is this right? Don't really understand this
            if output is NotImplemented:
                with torch._C.DisableTorchFunction():
                    output = func(*args, **kwargs).as_subclass(
                        QuantizationConvertTensorProxy)
                assert output is not NotImplemented

            if enable_logging:
                fqn_for_logging = module_id_to_fqn.get(
                    id(parent_module), 'unknown') if parent_module else None
                out_dtype = None
                if isinstance(output, torch.Tensor):
                    out_dtype = output.dtype
                logger.debug(
                    f" fqn:{fqn_for_logging} _tf_ {func} out {out_dtype} end")

            return output

        def __repr__(self):
            return f'QuantizationConvertTensorProxy({super().__repr__()})'

    cur_module = None
    module_stack: List[torch.nn.Module] = []

    assert len(module.__class__.__bases__) == 1

    class QuantizationDispatchModule(
            module.__class__.__bases__[0]):  # type: ignore[name-defined]
        """
        An override of user defined subclass of `nn.Module` to enable
        dynamic tracing for quantization, after model conversion
        to quantized domain.

        `cur_module` keeps track of the current module in the stack.

        Tensor arguments are converted to `QuantizationConvertTensorProxy`.

        We override the `__call__` function to do the following for each
        module:

        If the module is an op which needs quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls parent module's `._auto_quant_state.op_convert_before_hook`
        3. executes the original module forward
        4. calls parent module's `_auto_quant_state.op_convert_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        If the module can contain children ops that need quantization:

        1. calls `_auto_quant_state.inputs_convert_hook` (not implemented yet)
        2. executes the original module forward
        3. calls `_auto_quant_state.outputs_convert_hook`

        Otherwise, calls the original module forward.
        """

        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_dispatch_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                nonlocal global_disable_torch_function_override
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    hook_type = get_module_hook_type(parent_module, cur_module)

                    if enable_logging:
                        fqn_for_logging = module_id_to_fqn.get(id(self), None)
                        logger.debug(
                            f" fqn: {fqn_for_logging} " +
                            f"_cl_ {type(self)} " +
                            f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " +
                            f"hook_type {hook_type}")

                    if hook_type is HookType.OP_HOOKS:
                        # before hooks
                        qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf
                        # module. Therefore, we do not need to override any of
                        # its children. Disabling the overrides for
                        # performance.
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        _, args, kwargs = qstate.op_convert_before_hook(
                            cur_module, args, kwargs, cur_module)

                        # forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        output = qstate.op_convert_after_hook(
                            cur_module, output, global_op_idx)

                        # Re-enable the override.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        cur_qstate: AutoQuantizationState = \
                            cur_module._auto_quant_state
                        cur_qstate.reset_to_new_call()

                        # before hooks (TODO)

                        # forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks

                        # For the sake of performance, we assume no overrides
                        # are needed for quantizing/dequantizing things
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        output = cur_qstate.outputs_convert_hook(output)

                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        # TODO(future PR): handle more dtypes
                        new_args = []
                        for arg in args:
                            if isinstance(arg, torch.Tensor) and \
                                    arg.is_quantized:
                                dequant = arg.dequantize().as_subclass(
                                    QuantizationConvertTensorProxy
                                )  # type: ignore[arg-type]
                                new_args.append(dequant)
                            else:
                                new_args.append(arg)
                        args = tuple(new_args)
                        output = orig_module_call(self, *args, **kwargs)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn_for_logging = module_id_to_fqn.get(id(self), None)
                        logger.debug(
                            f" fqn: {fqn_for_logging} " +
                            f"_cl_ {type(self)} " +
                            f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " +
                            "end")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            try:
                global_op_idx[0] = 0
                output = super().__call__(*new_args, **new_kwargs)

                def unwrap_proxy(a):
                    if isinstance(a, QuantizationConvertTensorProxy):
                        a.__class__ = torch.Tensor  # type: ignore[assignment]
                    return a

                output = map_aggregate(output, unwrap_proxy)
                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]

        def rewrite_for_scripting(self):
            return auto_trace_rewriter.rewrite_for_scripting(self)

    pack_weights_for_functionals(module)
    attach_scale_zp_values_to_model(module)
    attach_op_convert_info_to_model(module)
    attach_output_convert_info_to_model(module)

    # Since eager mode convert could have changed the IDs of some modules,
    # populate the FQN map again
    for k, v in module.named_modules():
        module_id_to_fqn[id(v)] = k

    module.__class__ = QuantizationDispatchModule
    return module
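
# A small self-contained sketch (invented here, not part of the original
# source) of the flag-based fast path that
# `global_disable_torch_function_override` implements above: while the flag
# is set, the proxy skips its bookkeeping and behaves like a plain tensor,
# which is how the code above avoids dispatch overhead inside
# already-validated leaf modules.
def _make_toggleable_proxy():
    import torch

    override_disabled = [False]  # mutable cell shared with the class below
    op_count = [0]               # bookkeeping done only on the slow path

    class ToggleableProxy(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs if kwargs else {}
            if override_disabled[0] or func == torch.Tensor.__repr__:
                # fast path: no bookkeeping, plain tensor behavior
                return super().__torch_function__(func, types, args, kwargs)
            # slow path: count intercepted ops (stands in for the hooks)
            op_count[0] += 1
            return super().__torch_function__(func, types, args, kwargs)

    return ToggleableProxy, override_disabled, op_count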
def add_auto_observation(
    model: torch.nn.Module,
    qconfig_dict: Dict[str, Any],
    example_inputs: Tuple[Any],
    input_dtypes: Any = (torch.float,),  # must be same structure as model inputs
    prepare_custom_config_dict: Dict[str, Any] = None,
) -> torch.nn.Module:
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    output_dtypes = prepare_custom_config_dict.get('output_dtypes',
                                                   (torch.float,))

    def convert_to_interception_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(
                QuantizationPrepareTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    cur_module = None
    first_call = True
    module_stack: List[torch.nn.Module] = []
    # Counter for tensor IDs, will be modified inplace by quant state.
    # This is used to track tensors from output ops to input ops. For example,
    # if op_n had a tensor output with id=1, and op_n+2 had a tensor input with
    # id=1, we know that the output of op_n is the input to op_n+2. Note,
    # this is a list because it needs to be incremented inplace.
    qtensor_id = [0]
    module_id_to_fqn: Dict[int, str] = {}

    # Counter for global quantizeable ops, useful for intermediate activation
    # logging.
    global_op_idx = [0]

    global_disable_torch_function_override = False

    class QuantizationPrepareTensorProxy(torch.Tensor):
        """
        An override of `torch.Tensor` to enable dynamic tracing for
        quantization.

        For each function with a `__torch_function__` override, this proxy
        does the following for functions which need quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls `_auto_quant_state.op_prepare_before_hook`
        3. executes the original function
        4. calls `_auto_quant_state.op_prepare_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        Otherwise, calls the original function.
        """

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            nonlocal global_disable_torch_function_override
            if (
                    # global override means disable the override here
                    global_disable_torch_function_override or
                    # to prevent printing things from going into an infinite
                    # loop
                    func == torch.Tensor.__repr__ or
                    # we don't need to override getters in this framework
                    func.__name__ == '__get__'
            ):
                return super().__torch_function__(func, types, args, kwargs)

            # if we are in a function, the current module is always a parent
            nonlocal cur_module
            parent_module = cur_module
            if enable_logging:
                if not is_activation_post_process(parent_module):
                    # logging for insides of obs/fq is not useful for this
                    # framework

                    # fqn map does not contain observers, which is why we
                    # cannot always assume that FQN exists
                    fqn_for_logging = module_id_to_fqn.get(
                        id(parent_module), 'unknown') if parent_module else None
                    logger.debug(
                        f' fqn:{fqn_for_logging} _tf_ {str(func)} '
                        f'len_args {len(args)}')

            nonlocal qtensor_id
            kwargs = kwargs if kwargs else {}
            hook_type = get_torch_function_hook_type(parent_module, func)

            if hook_type is HookType.OP_HOOKS:
                fqn = module_id_to_fqn[id(
                    parent_module)] if parent_module else None
                qstate = parent_module._auto_quant_state  # type: ignore[attr-defined]
                if not first_call:
                    qstate.validate_cur_op(func)

                # run "before" hook
                if first_call:
                    args, kwargs = qstate.first_call_op_prepare_before_hook(
                        func, args, kwargs, qtensor_id, fqn, parent_module,
                        OpQuantizeabilityType.QUANTIZEABLE)
                else:
                    args, kwargs = qstate.op_prepare_before_hook(
                        func, args, kwargs)

                # forward
                output = super().__torch_function__(func, types, args, kwargs)

                # run "after" hook
                if first_call:
                    output = qstate.first_call_op_prepare_after_hook(
                        func, output, args, qtensor_id,
                        OpQuantizeabilityType.QUANTIZEABLE)
                else:
                    output = qstate.op_prepare_after_hook(
                        func, output, args, global_op_idx)

                qstate.mark_cur_op_complete(func)
            else:
                # Hook type is not HookType.OP_HOOKS, if first_call is True we
                # record the DAG of non-quantizeable ops.
                if first_call:
                    qstate = getattr(parent_module, '_auto_quant_state', None)
                    if qstate:
                        fqn = module_id_to_fqn.get(id(parent_module), None) \
                            if parent_module else None
                        args, kwargs = \
                            qstate.first_call_op_prepare_before_hook(
                                func, args, kwargs, qtensor_id, fqn,
                                parent_module,
                                OpQuantizeabilityType.NOT_QUANTIZEABLE)

                output = super().__torch_function__(func, types, args, kwargs)

                if first_call:
                    qstate = getattr(parent_module, '_auto_quant_state', None)
                    if qstate:
                        output = qstate.first_call_op_prepare_after_hook(
                            func, output, args, qtensor_id,
                            OpQuantizeabilityType.NOT_QUANTIZEABLE)

            # TODO: is this right? Don't really understand this
            if output is NotImplemented:
                with torch._C.DisableTorchFunction():
                    output = func(*args, **kwargs).as_subclass(
                        QuantizationPrepareTensorProxy)
                assert output is not NotImplemented

            return output

        def __repr__(self):
            return f'QuantizationPrepareTensorProxy({super().__repr__()})'

    # TODO(future PR): add other math overrides

    class QuantizationInterceptionModule(type(model)):  # type: ignore[misc]
        """
        An override of user defined subclass of `nn.Module` to enable
        dynamic tracing for quantization.

        `cur_module` keeps track of the current module in the stack.

        During the first call, an `AutoQuantizationState` object is created
        and attached to each non-leaf module which we need to check for
        quantizeable operations.

        We override the `__call__` function to do the following for each
        module:

        If the module is an op which needs quantization:

        1. calls `_auto_quant_state.validate_cur_op` to validate that
           the currently seen op is the same as what was recorded during
           tracing
        2. calls parent module's `._auto_quant_state.op_prepare_before_hook`
        3. executes the original module forward
        4. calls parent module's `_auto_quant_state.op_prepare_after_hook`
        5. calls `_auto_quant_state.mark_cur_op_complete` to increment
           the current op index in preparation for the next op

        If the module can contain children ops that need quantization:

        1. calls `_auto_quant_state.inputs_prepare_hook` (not implemented yet)
        2. executes the original module forward
        3. calls `_auto_quant_state.outputs_prepare_hook`

        Otherwise, calls the original module forward.
        """

        def __call__(self, *args, **kwargs):
            new_args = map_aggregate(args, convert_to_interception_proxy)
            new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            def _patched_module_call(self, *args, **kwargs):
                if enable_logging:
                    fqn = module_id_to_fqn.get(id(self), None)
                    logger.debug(f" fqn:{fqn} _cl_: {type(self)} start")

                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                try:
                    parent_module = module_stack[-1] if len(
                        module_stack) else None
                    module_stack.append(self)
                    fqn = module_id_to_fqn.get(id(self), None)

                    hook_type = get_module_hook_type(parent_module, cur_module)
                    if hook_type is HookType.OP_HOOKS:
                        parent_qstate: AutoQuantizationState = \
                            parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                        # before hooks
                        if not first_call:
                            parent_qstate.validate_cur_op(cur_module)

                        # If we are in this hook, `cur_module` is a leaf
                        # module. Therefore, we do not need to override any of
                        # its children. Disabling the overrides for
                        # performance.
                        nonlocal global_disable_torch_function_override
                        old_global_disable_torch_function_override = \
                            global_disable_torch_function_override
                        global_disable_torch_function_override = True

                        if first_call:
                            # mypy ignore is used instead of assert because
                            # this runs on every forward and assert has a
                            # performance cost
                            args, kwargs = \
                                parent_qstate.first_call_op_prepare_before_hook(
                                    cur_module, args, kwargs, qtensor_id, fqn,
                                    cur_module,  # type: ignore[arg-type]
                                    OpQuantizeabilityType.QUANTIZEABLE)
                        else:
                            # mypy ignore is used instead of assert because
                            # this runs on every forward and assert has a
                            # performance cost
                            args, kwargs = parent_qstate.op_prepare_before_hook(
                                cur_module, args, kwargs)  # type: ignore[arg-type]

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # Re-enable the overrides.
                        global_disable_torch_function_override = \
                            old_global_disable_torch_function_override

                        # after hooks
                        if first_call:
                            output = \
                                parent_qstate.first_call_op_prepare_after_hook(
                                    cur_module, output, args, qtensor_id,
                                    OpQuantizeabilityType.QUANTIZEABLE)
                        else:
                            output = parent_qstate.op_prepare_after_hook(
                                cur_module, output, args, global_op_idx)
                        parent_qstate.mark_cur_op_complete(cur_module)

                    elif hook_type is HookType.MODULE_IO_HOOKS:
                        # TODO(future PR): add inputs io hook
                        cur_qstate = cur_module._auto_quant_state
                        cur_qstate.reset_to_new_call()

                        # original forward
                        output = orig_module_call(self, *args, **kwargs)

                        # after hooks
                        if first_call:
                            output = cur_qstate.first_call_outputs_prepare_hook(
                                output, qtensor_id)
                        else:
                            output = cur_qstate.outputs_prepare_hook(output)

                        cur_qstate.validate_is_at_last_seen_idx()

                    elif hook_type is HookType.ARG_DEQUANTS:
                        if first_call and parent_module is not None:
                            parent_qstate_fc = getattr(
                                parent_module, '_auto_quant_state', None)
                            if parent_qstate_fc:
                                args, kwargs = \
                                    parent_qstate_fc.first_call_op_prepare_before_hook(
                                        cur_module, args, kwargs, qtensor_id,
                                        fqn, cur_module,
                                        OpQuantizeabilityType.NOT_QUANTIZEABLE)

                        output = orig_module_call(self, *args, **kwargs)

                        # if this fp32 was inplace, make sure to set the
                        # output dtype back to torch.float
                        if hasattr(output, '_qtensor_info'):
                            del output._qtensor_info

                        if first_call and parent_module is not None:
                            parent_qstate_fc = getattr(
                                parent_module, '_auto_quant_state', None)
                            if parent_qstate_fc:
                                output = \
                                    parent_qstate_fc.first_call_op_prepare_after_hook(
                                        cur_module, output, args, qtensor_id,
                                        OpQuantizeabilityType.NOT_QUANTIZEABLE)

                    else:
                        output = orig_module_call(self, *args, **kwargs)

                    if enable_logging:
                        fqn = module_id_to_fqn.get(id(self), None)
                        logger.debug(f" fqn:{fqn} _cl_: {type(self)} end")

                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

            nonlocal first_call
            try:
                if first_call:
                    # Create a list before iterating because we are adding new
                    # named modules inside the loop.
                    named_modules = list(self.named_modules())

                    # Record module instances which are leaves or children of
                    # leaves
                    leaves = set()
                    for fqn, child in named_modules:
                        if is_leaf(child, prepare_custom_config_dict):
                            for _, child_child in child.named_modules():
                                leaves.add(child_child)

                    self._fqn_to_auto_quant_state_map = \
                        AutoQuantizationStateModuleDict()

                    for fqn, v in named_modules:
                        # fqn is the global FQN, i.e. 'foo.bar.baz'
                        # v is the module instance
                        #
                        # we need to associate the global FQN with SeenOp
                        # for modules, this is the module FQN
                        # for functions, this is the parent module FQN
                        module_id_to_fqn[id(v)] = fqn

                        if v in leaves:
                            continue

                        if v is self:
                            # for the top level module only, specify input
                            # and output dtypes
                            auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn, input_dtypes, output_dtypes)
                        else:
                            auto_quant_state = AutoQuantizationState(
                                qconfig_dict, fqn)

                        # The code below registers the auto_quant_state object
                        # of the child in the module hierarchy of the parent,
                        # and adds the auto_quant_state object to the child
                        # with a raw __setattr__, without registering it in
                        # the module hierarchy of the child.
                        # This is solving the problem of both storing extra
                        # state (observers) as well as not modifying the
                        # meaning of user code in child modules which iterates
                        # over all module children.
                        #
                        # This narrows down the issue of dynamically adding
                        # children to only affect the top level module and not
                        # the children.

                        # On the parent, register this module in the FQN map
                        fqn_to_use_for_key = \
                            get_fqn_valid_for_module_dict_key(fqn)
                        self._fqn_to_auto_quant_state_map[fqn_to_use_for_key] = \
                            auto_quant_state
                        # On the child, manually set the attribute without
                        # going through the `torch.nn.Module.__setattr__`
                        # function, to prevent this object from appearing in
                        # the child's module hierarchy.
                        object.__setattr__(
                            v, '_auto_quant_state', auto_quant_state)

                global_op_idx[0] = 0

                output = super().__call__(*new_args, **new_kwargs)

                if first_call:
                    for _, v in self.named_modules():
                        if hasattr(v, '_auto_quant_state'):
                            v._auto_quant_state.match_fusion_patterns()
                            v._auto_quant_state.insert_observers(v)

                return output
            finally:
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
                first_call = False

    model.__class__ = QuantizationInterceptionModule
    # create the graph
    trace_with_inputs(model, example_inputs)
    return model
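
# A usage sketch for this qconfig_dict-based variant (toy model invented
# here; the `{'': qconfig}` format mirrors the FX-style qconfig_dict
# convention and is an assumption, not confirmed by this file).
def _example_prepare_convert_qconfig_dict():
    import torch

    class _ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    m = _ToyModel().eval()
    qconfig_dict = {'': torch.ao.quantization.default_qconfig}
    example_inputs = (torch.randn(1, 4),)
    m = add_auto_observation(m, qconfig_dict, example_inputs)
    m(*example_inputs)       # calibration
    m = add_auto_convert(m)  # convert the observed model for int8 inference
    return m(*example_inputs)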
def get_fqn_to_example_inputs(
    model: torch.nn.Module,
    example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
    """ Given a model and its example inputs, return a dictionary from
    fully qualified names of submodules to the example_inputs for those
    submodules, e.g.
    {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
     "sub.linear1": (tensor4,), ...}

    Used to make quantizing submodules easier now that FX Graph Mode
    Quantization requires example inputs.

    Also works for keyword arguments with default values: we flatten keyword
    arguments into positional arguments and fill in the missing keyword args
    with default values. E.g. if we have a forward function:

    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6)
    we'll get example_inputs: (x, 3, 6)

    The user can also override `key1` with a positional argument:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    Variable positional arguments and variable keyword arguments in the
    forward function are not supported currently, so please make sure no
    submodule uses them.
    """
    root = model
    fqn_to_example_inputs = {}

    class InterceptionModule(type(model)):  # type: ignore[misc]
        def __call__(self, *args, **kwargs):
            orig_module_call = torch.nn.Module.__call__

            def _patched_module_call(self, *args, **kwargs):
                submodule_example_inputs = list(args).copy()
                normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
                # minus 1 to skip counting `self`
                num_args = _get_num_pos_args(self.forward) - 1
                num_to_pop = num_args - len(submodule_example_inputs)
                while num_to_pop and normalized_kwargs:
                    normalized_kwargs.popitem(last=False)
                    num_to_pop -= 1
                submodule_example_inputs.extend(normalized_kwargs.values())
                submodule_example_inputs_tuple = tuple(
                    submodule_example_inputs)
                fqn = _get_path_of_module(root, self)
                if fqn is not None:
                    fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
                return orig_module_call(self, *args, **kwargs)

            torch.nn.Module.__call__ = _patched_module_call
            try:
                super().__call__(*args, **kwargs)
            finally:
                # restore the original __call__ even if forward raises
                torch.nn.Module.__call__ = orig_module_call

    original_class = model.__class__
    model.__class__ = InterceptionModule
    model(*example_inputs)
    model.__class__ = original_class
    return fqn_to_example_inputs
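
# A usage sketch (toy modules invented here) matching the docstring above:
# one forward pass collects per-submodule example inputs keyed by FQN, with
# the `scale` keyword flattened into the positional tuple.
def _example_get_fqn_to_example_inputs():
    import torch

    class Sub(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)

        def forward(self, x, scale=1.0):
            return self.linear1(x) * scale

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sub = Sub()

        def forward(self, x):
            return self.sub(x, scale=2.0)

    m = M()
    fqn_map = get_fqn_to_example_inputs(m, (torch.randn(1, 4),))
    # expected to include 'sub' -> (tensor, 2.0) and 'sub.linear1' -> (tensor,)
    return fqn_map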
def add_nncf_functionality_to_user_module(module: torch.nn.Module):
    user_class = module.__class__
    assert user_class.__name__ in UNWRAPPED_USER_MODULES.registry_dict
    module.__class__ = NNCF_WRAPPED_USER_MODULES_DICT[user_class]
    _NNCFModuleMixin.add_mixin_fields(module)
    return module
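
# A generic sketch (invented names, not NNCF's actual registry) of the
# dynamic class-swap pattern used above: replacing an instance's `__class__`
# with a registered wrapper subclass attaches extra framework behavior
# without copying the module's parameters or buffers.
def _class_swap_demo():
    import torch

    class UserConv(torch.nn.Conv2d):
        pass

    class WrappedUserConv(UserConv):
        def extra_behavior(self):
            return "wrapped"

    conv = UserConv(3, 3, kernel_size=1)
    conv.__class__ = WrappedUserConv  # instance keeps its state in place
    assert conv.extra_behavior() == "wrapped"
    return conv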