def run_node(self, n: Node) -> Any:
    try:
        result = super().run_node(n)
    except Exception:
        traceback.print_exc()
        raise RuntimeError(
            f"ShapeProp error for: node={n.format_node()} with "
            f"meta={n.meta}"
        )

    found_tensor = False

    def extract_tensor_meta(obj):
        if isinstance(obj, torch.Tensor):
            nonlocal found_tensor
            found_tensor = True
            return _extract_tensor_metadata(obj)
        else:
            return obj

    meta = map_aggregate(result, extract_tensor_meta)
    if found_tensor:
        n.meta['tensor_meta'] = meta

    n.meta['type'] = type(result)
    return result
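# A minimal usage sketch (assumption: the run_node above belongs to a
# ShapeProp-style torch.fx Interpreter such as
# torch.fx.passes.shape_prop.ShapeProp). Running propagate() populates
# node.meta['tensor_meta'] for every node, which the serializers further
# below rely on.
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.lin(x))

gm = torch.fx.symbolic_trace(TwoLayer())
ShapeProp(gm).propagate(torch.randn(3, 4))
for node in gm.graph.nodes:
    print(node.name, node.meta.get('tensor_meta', None))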
def __torch_function__(self, func, types, args=(), kwargs={}):
    namespace, func_name = func.split("::")
    func = getattr(getattr(torch.ops, namespace), func_name)
    outs = kwargs['val']
    rets = []
    proxy_args = map_aggregate(
        args, lambda i: i.proxy if isinstance(i, PythonTensor) else i)
    out_proxy = func(*proxy_args)
    if len(outs) == 1 and isinstance(outs[0], torch.Tensor):
        return [PythonTensor(outs[0], out_proxy)]
    for idx, out in enumerate(outs):
        if isinstance(out, torch.Tensor):
            rets.append(PythonTensor(out, out_proxy[idx]))
        else:
            rets.append(out)
    return rets
def run_node(self, n):
    result = super().run_node(n)

    found_tensor = False

    def extract_tensor_meta(obj):
        if isinstance(obj, torch.Tensor):
            nonlocal found_tensor
            found_tensor = True
            return obj
        else:
            return obj

    from torch.fx.node import map_aggregate
    concrete_value = map_aggregate(result, extract_tensor_meta)
    if found_tensor:
        n.meta['concrete_value'] = concrete_value
    return result
def run_node(self, n: Node) -> Any:
    result = super().run_node(n)

    found_tensor = False

    def extract_tensor_meta(obj):
        if isinstance(obj, torch.Tensor):
            nonlocal found_tensor
            found_tensor = True
            return extract_tensor_metadata(obj)
        else:
            return obj

    meta = map_aggregate(result, extract_tensor_meta)
    if found_tensor:
        n.meta['tensor_meta'] = meta

    n.meta['type'] = type(result)
    return result
def run_node(self, n: Node) -> Any:
    args, kwargs = self.fetch_args_kwargs_from_env(n)

    def get_type(arg):
        if isinstance(arg, fx.Node):
            return n.meta['type'] if 'type' in n.meta else None
        return type(arg)

    arg_types = map_aggregate(n.args, get_type)
    assert isinstance(arg_types, tuple)
    arg_types = tuple([create_type_hint(i) for i in arg_types])
    kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
    if n.op == 'call_function':
        out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
    else:
        out = super().run_node(n)
    self.node_map[out] = n
    return out
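# A small, self-contained illustration of torch.fx.node.map_aggregate, which
# the run_node variants above use to walk arbitrarily nested args/results
# (tuples, lists, dicts, slices) and apply a function to every leaf.
import torch
from torch.fx.node import map_aggregate

nested = (torch.randn(2, 3), [torch.randn(4), 7], {"k": torch.randn(1)})

# Tensors are replaced by their shapes; non-tensor leaves pass through unchanged.
shapes = map_aggregate(
    nested,
    lambda a: tuple(a.shape) if isinstance(a, torch.Tensor) else a,
)
print(shapes)  # ((2, 3), [(4,), 7], {'k': (1,)})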
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively serializes a graph module (fx_module) to a dictionary which
    is later exported to JSON. It also adds all weights to the provided weights
    dictionary by qualified_name.

    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {},
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int,
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight = serialize_weight(p)
            serialized_dict["weights"][prefix + name] = weight
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        shape, dtype = get_shape_and_dtype(node)
        tensor_meta = node.meta.get('tensor_meta')
        if not tensor_meta:
            raise RuntimeError(
                f'Node {node} has no tensor metadata! Ensure shape '
                f'propagation has been run!'
            )
        node_rep = {
            "shape": serialize_shape(shape),
            "dtype": str(dtype),
            "is_quantized": tensor_meta.is_quantized,
        }
        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)
            if tensor_meta.qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point
        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)

    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}
        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent."):]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(weights[stripped_name])
                serialized_dict["weights"][stripped_name] = weight
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it
                # already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target)
                    serialized_dict["weights"][qualname] = weight
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, torch.dtype):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )
            # If there are multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
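# A minimal usage sketch for the serializer above (assumption: serialize_module
# and its helpers are importable, e.g. from torch.fx.experimental.graph_manipulation
# in the PyTorch versions that shipped them; treat the import path as illustrative).
# Shape propagation must run first so that every node carries 'tensor_meta'.
import json
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.experimental.graph_manipulation import serialize_module  # illustrative path

model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)
ShapeProp(gm).propagate(torch.randn(1, 4))

weights: dict = {}
serialized = serialize_module(gm, weights)
print(json.dumps(serialized, indent=2)[:400])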
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively serializes a graph module (fx_module) to a dictionary which
    is later exported to JSON. It also adds all weights to the provided weights
    dictionary by qualified_name.

    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {},
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int,
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def add_weight_tensors(named_tensors):
        for name, p in named_tensors:
            if name.startswith("parent.") or not isinstance(p, torch.Tensor):
                continue
            weight_dict = serialize_weight(p, weights, prefix + name)
            serialized_dict["weights"].update(weight_dict)
            weights[prefix + name] = p

    add_weight_tensors(fx_module.named_parameters())
    add_weight_tensors(fx_module.named_buffers())

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qscheme)

            if tensor_meta.qscheme in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.q_scale
                node_rep["q_zero_point"] = tensor_meta.q_zero_point

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)

    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}

        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (
            not (
                node.op == "call_module"
                and isinstance(submodules[node.target], GraphModule)
            )
            and node.op != "output"
        ):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                stripped_name = node.target[len("parent."):]
                node.name = stripped_name
                node_rep["target"] = stripped_name
                weight = serialize_weight(
                    weights[stripped_name], weights, node.target[len("parent."):]
                )
                # For quantized embedding tables we need to update the shape/type,
                # so we check if the users of this get_attr are a quantized EB and
                # this is the weight for the EB.
                user_targets = {
                    _get_qualified_name(n.target)
                    .replace("torch.fx.experimental.fx_acc.", "")
                    .replace("glow.fb.fx.", ""): n
                    for n in node.users.keys()
                }
                if (
                    "acc_ops.embedding_bag_byte_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_byte_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint8fused"
                # Same as above, but for the 4 bit version.
                if (
                    "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
                    and str(
                        user_targets[
                            "acc_ops.embedding_bag_4bit_rowwise_offsets"
                        ].kwargs["weight"]
                    )
                    == stripped_name
                ):
                    weight[stripped_name]["dtype"] = "acc.uint4fused"

                serialized_dict["weights"].update(weight)
            else:
                # Find the actual target parameter/buffer from the fx_module.
                submod_path, _, target_name = node.target.rpartition(".")
                submod: Optional[torch.nn.Module] = (
                    fx_module.get_submodule(submod_path) if submod_path else fx_module
                )
                assert submod is not None, f"submod {submod_path} not found"
                target = getattr(submod, target_name, None)
                assert target is not None, f"{target_name} not an attr of {submod_path}"
                qualname = prefix + node.target
                # Check that the target is a tensor, and that we haven't added it
                # already from a leaf module.
                if isinstance(target, torch.Tensor) and qualname not in weights:
                    weight = serialize_weight(target, weights, qualname)
                    serialized_dict["weights"].update(weight)
                    weights[qualname] = target

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )
            # If there are multiple outputs then node_rep["args"][0] will be a tuple.
            # In this case we want to unpack the tuple.
            if isinstance(node_rep["args"][0], tuple):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
def __call__(self, *args, **kwargs):
    new_args = map_aggregate(args, convert_to_dispatch_proxy)
    new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
    orig_module_call = torch.nn.Module.__call__
    orig_nn_sequential_forward = torch.nn.Sequential.forward

    def _patched_module_call(self, *args, **kwargs):
        if enable_logging:
            fqn = module_id_to_fqn.get(id(self), None)
            logger.debug(f"\nstarting fqn {fqn}")

        nonlocal cur_module
        old_module = cur_module
        cur_module = self
        try:
            parent_module = module_stack[-1] if len(module_stack) else None
            module_stack.append(self)
            hook_type = get_module_hook_type(parent_module, cur_module)
            if enable_logging:
                logger.debug(
                    f"_patched_module_call {type(self)} " +
                    # f"arg_types {[type(arg) for arg in args]} " +
                    f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " +
                    f"hook_type {hook_type}")

            if hook_type is HookType.OP_HOOKS:
                # before hooks
                assert parent_module is not None
                assert isinstance(parent_module._auto_quant_state, AutoQuantizationState)
                qstate = parent_module._auto_quant_state
                if enable_logging:
                    logger.debug(qstate)
                qstate.validate_cur_op(cur_module)
                _, args, kwargs = qstate.op_convert_before_hook(
                    cur_module, args, kwargs, cur_module)

                # forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                output = qstate.op_convert_after_hook(cur_module, output)
                qstate.mark_cur_op_complete(cur_module)

            elif hook_type is HookType.MODULE_IO_HOOKS:
                cur_qstate = cur_module._auto_quant_state
                if enable_logging:
                    logger.debug(cur_qstate)
                cur_qstate.validate_is_at_first_idx()

                # before hooks (TODO)

                # forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                assert isinstance(cur_qstate, AutoQuantizationState)
                output = cur_qstate.outputs_convert_hook(output)
                cur_qstate.validate_is_at_last_seen_idx()

            elif hook_type is HookType.ARG_DEQUANTS:
                # disabling torch function to prevent infinite recursion on
                # getset
                # TODO(future PR): handle more dtypes
                with torch._C.DisableTorchFunction():
                    new_args = []
                    for arg in args:
                        if isinstance(arg, torch.Tensor) and arg.is_quantized:
                            dequant = arg.dequantize().as_subclass(
                                QuantizationConvertTensorProxy
                            )  # type: ignore[arg-type]
                            new_args.append(dequant)
                        else:
                            new_args.append(arg)
                    args = tuple(new_args)
                output = orig_module_call(self, *args, **kwargs)

            else:
                output = orig_module_call(self, *args, **kwargs)

            if enable_logging:
                logger.debug(
                    f"_patched_module_call {type(self)} " +
                    # f"out {type(output)} " +
                    f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " +
                    "end")
                logger.debug(f"ending fqn {fqn}\n")
            return output
        finally:
            module_stack.pop()
            cur_module = old_module

    torch.nn.Module.__call__ = _patched_module_call
    torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

    try:
        for k, v in self.named_modules():
            module_id_to_fqn[id(v)] = k
            if hasattr(v, '_auto_quant_state'):
                v._auto_quant_state.reset_to_new_call()

        needs_io_hooks = hasattr(self, '_auto_quant_state')

        # handle module input dtype conversions
        # TODO(implement)

        output = super().__call__(*new_args, **new_kwargs)

        # handle module output dtype conversions
        if needs_io_hooks:
            qstate = self._auto_quant_state
            assert isinstance(qstate, AutoQuantizationState)
            output = qstate.outputs_convert_hook(output)

        def unwrap_proxy(a):
            if isinstance(a, QuantizationConvertTensorProxy):
                a.__class__ = torch.Tensor  # type: ignore[assignment]
            return a

        output = map_aggregate(output, unwrap_proxy)
        return output
    finally:
        torch.nn.Module.__call__ = orig_module_call
        torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
def __call__(self, *args, **kwargs):
    new_args = map_aggregate(args, convert_to_interception_proxy)
    new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
    orig_module_call = torch.nn.Module.__call__
    orig_nn_sequential_forward = torch.nn.Sequential.forward

    def _patched_module_call(self, *args, **kwargs):
        if enable_logging:
            logger.debug(f"_patched_module_call: {type(self)}")

        nonlocal cur_module
        old_module = cur_module
        cur_module = self
        try:
            parent_module = module_stack[-1] if len(module_stack) else None
            module_stack.append(self)
            fqn = module_id_to_fqn.get(id(self), None)
            if enable_logging:
                fqn = module_id_to_fqn.get(id(self), None)
                logger.debug(f"\nstarting fqn {fqn}")

            hook_type = get_module_hook_type(parent_module, cur_module)

            if hook_type is HookType.OP_HOOKS:
                assert parent_module is not None
                parent_qstate = parent_module._auto_quant_state
                assert isinstance(parent_qstate, AutoQuantizationState)
                # before hooks
                if not first_call:
                    parent_qstate.validate_cur_op(cur_module)
                args, kwargs = parent_qstate.op_prepare_before_hook(
                    cur_module, args, kwargs, first_call, qtensor_id, fqn,
                    cur_module)

                # original forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                # TODO is it correct to call_cur_module twice here?
                output = parent_qstate.op_prepare_after_hook(
                    cur_module, output, args, first_call, qtensor_id,
                    cur_module)
                parent_qstate.mark_cur_op_complete(cur_module)

            elif hook_type is HookType.MODULE_IO_HOOKS:
                # TODO(future PR): add inputs io hook
                cur_qstate = cur_module._auto_quant_state
                cur_qstate.validate_is_at_first_idx()

                # original forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                assert isinstance(cur_qstate, AutoQuantizationState)
                output = cur_qstate.outputs_prepare_hook(
                    output, first_call, qtensor_id)
                cur_qstate.validate_is_at_last_seen_idx()

            elif hook_type is HookType.ARG_DEQUANTS:
                output = orig_module_call(self, *args, **kwargs)
                # if this fp32 was inplace, make sure to set the output dtype
                # back to torch.float
                if hasattr(output, '_qtensor_info'):
                    del output._qtensor_info

            else:
                output = orig_module_call(self, *args, **kwargs)

            if enable_logging:
                fqn = module_id_to_fqn.get(id(self), None)
                logger.debug(f"\nending fqn {fqn}")

            return output
        finally:
            module_stack.pop()
            cur_module = old_module

    torch.nn.Module.__call__ = _patched_module_call
    torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
    nonlocal first_call
    try:
        # Create a list before iterating because we are adding new
        # named modules inside the loop.
        named_modules = list(self.named_modules())
        for k, v in named_modules:
            # k is the global FQN, i.e. 'foo.bar.baz'
            # v is the module instance
            #
            # we need to associate the global FQN with SeenOp
            # for modules, this is the module FQN
            # for functions, this is the parent module FQN
            module_id_to_fqn[id(v)] = k

            has_qconfig = hasattr(v, 'qconfig') and v.qconfig is not None
            if has_qconfig and not is_leaf(v):
                if first_call:
                    if v is self:
                        # for the top level module only, specify input
                        # and output dtypes
                        v._auto_quant_state = AutoQuantizationState(
                            v.qconfig, input_dtypes, output_dtypes)
                    else:
                        v._auto_quant_state = AutoQuantizationState(v.qconfig)
                else:
                    if not isinstance(v, AutoQuantizationState):
                        assert hasattr(v, '_auto_quant_state')
                        v._auto_quant_state.reset_to_new_call()

        output = super().__call__(*new_args, **new_kwargs)
        return output
    finally:
        torch.nn.Module.__call__ = orig_module_call
        torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
        first_call = False
def __call__(self, *args, **kwargs):
    new_args = map_aggregate(args, convert_to_dispatch_proxy)
    new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy)
    orig_module_call = torch.nn.Module.__call__
    orig_nn_sequential_forward = torch.nn.Sequential.forward

    def _patched_module_call(self, *args, **kwargs):
        nonlocal cur_module
        old_module = cur_module
        cur_module = self
        nonlocal global_disable_torch_function_override
        try:
            parent_module = module_stack[-1] if len(module_stack) else None
            module_stack.append(self)
            hook_type = get_module_hook_type(parent_module, cur_module)
            if enable_logging:
                fqn_for_logging = module_id_to_fqn.get(id(self), None)
                logger.debug(
                    f" fqn: {fqn_for_logging} " +
                    f"_cl_ {type(self)} " +
                    f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " +
                    f"hook_type {hook_type}")

            if hook_type is HookType.OP_HOOKS:
                # before hooks
                qstate: AutoQuantizationState = \
                    parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                qstate.validate_cur_op(cur_module)

                # If we are in this hook, `cur_module` is a leaf module.
                # Therefore, we do not need to override any of its
                # children. Disabling the overrides for performance.
                old_global_disable_torch_function_override = \
                    global_disable_torch_function_override
                global_disable_torch_function_override = True

                _, args, kwargs = qstate.op_convert_before_hook(
                    cur_module, args, kwargs, cur_module)

                # forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                output = qstate.op_convert_after_hook(
                    cur_module, output, global_op_idx)

                # Re-enable the override.
                global_disable_torch_function_override = \
                    old_global_disable_torch_function_override

                qstate.mark_cur_op_complete(cur_module)

            elif hook_type is HookType.MODULE_IO_HOOKS:
                cur_qstate: AutoQuantizationState = cur_module._auto_quant_state
                cur_qstate.reset_to_new_call()

                # before hooks (TODO)

                # forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                # For the sake of performance, we assume no overrides
                # are needed for quantizing/dequantizing things
                old_global_disable_torch_function_override = \
                    global_disable_torch_function_override
                global_disable_torch_function_override = True

                output = cur_qstate.outputs_convert_hook(output)

                global_disable_torch_function_override = \
                    old_global_disable_torch_function_override

                cur_qstate.validate_is_at_last_seen_idx()

            elif hook_type is HookType.ARG_DEQUANTS:
                # TODO(future PR): handle more dtypes
                new_args = []
                for arg in args:
                    if isinstance(arg, torch.Tensor) and arg.is_quantized:
                        dequant = arg.dequantize().as_subclass(
                            QuantizationConvertTensorProxy
                        )  # type: ignore[arg-type]
                        new_args.append(dequant)
                    else:
                        new_args.append(arg)
                args = tuple(new_args)
                output = orig_module_call(self, *args, **kwargs)

            else:
                output = orig_module_call(self, *args, **kwargs)

            if enable_logging:
                fqn_for_logging = module_id_to_fqn.get(id(self), None)
                logger.debug(
                    f" fqn: {fqn_for_logging} " +
                    f"_cl_ {type(self)} " +
                    f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " +
                    "end")
            return output
        finally:
            module_stack.pop()
            cur_module = old_module

    torch.nn.Module.__call__ = _patched_module_call
    torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]

    try:
        global_op_idx[0] = 0
        output = super().__call__(*new_args, **new_kwargs)

        def unwrap_proxy(a):
            if isinstance(a, QuantizationConvertTensorProxy):
                a.__class__ = torch.Tensor  # type: ignore[assignment]
            return a

        output = map_aggregate(output, unwrap_proxy)
        return output
    finally:
        torch.nn.Module.__call__ = orig_module_call
        torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
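# A stripped-down sketch of the interception pattern used by the __call__
# overrides above: torch.nn.Module.__call__ is temporarily swapped for a
# wrapper so that every child module call can run "before" and "after" hooks,
# and the original is always restored in a finally block. Names here
# (run_with_patched_call) are illustrative, not part of the real implementation.
import torch

def run_with_patched_call(model: torch.nn.Module, *args):
    orig_module_call = torch.nn.Module.__call__

    def _patched_module_call(self, *a, **kw):
        print(f"entering {type(self).__name__}")   # "before" hook
        out = orig_module_call(self, *a, **kw)      # original forward
        print(f"leaving  {type(self).__name__}")    # "after" hook
        return out

    torch.nn.Module.__call__ = _patched_module_call
    try:
        return model(*args)
    finally:
        # Always restore the original, even if the forward raises.
        torch.nn.Module.__call__ = orig_module_call

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
print(run_with_patched_call(model, torch.randn(2, 4)).shape)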
def __call__(self, *args, **kwargs):
    new_args = map_aggregate(args, convert_to_interception_proxy)
    new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
    orig_module_call = torch.nn.Module.__call__
    orig_nn_sequential_forward = torch.nn.Sequential.forward

    def _patched_module_call(self, *args, **kwargs):
        if enable_logging:
            fqn = module_id_to_fqn.get(id(self), None)
            logger.debug(f" fqn:{fqn} _cl_: {type(self)} start")

        nonlocal cur_module
        old_module = cur_module
        cur_module = self
        try:
            parent_module = module_stack[-1] if len(module_stack) else None
            module_stack.append(self)
            fqn = module_id_to_fqn.get(id(self), None)

            hook_type = get_module_hook_type(parent_module, cur_module)

            if hook_type is HookType.OP_HOOKS:
                parent_qstate: AutoQuantizationState = \
                    parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                # before hooks
                if not first_call:
                    parent_qstate.validate_cur_op(cur_module)

                # If we are in this hook, `cur_module` is a leaf module.
                # Therefore, we do not need to override any of its
                # children. Disabling the overrides for performance.
                nonlocal global_disable_torch_function_override
                old_global_disable_torch_function_override = \
                    global_disable_torch_function_override
                global_disable_torch_function_override = True

                if first_call:
                    # mypy ignore is used instead of assert because this
                    # runs on every forward and assert has a performance cost
                    args, kwargs = parent_qstate.first_call_op_prepare_before_hook(
                        cur_module, args, kwargs, qtensor_id, fqn, cur_module,  # type: ignore[arg-type]
                        OpQuantizeabilityType.QUANTIZEABLE)
                else:
                    # mypy ignore is used instead of assert because this
                    # runs on every forward and assert has a performance cost
                    args, kwargs = parent_qstate.op_prepare_before_hook(
                        cur_module, args, kwargs)  # type: ignore[arg-type]

                # original forward
                output = orig_module_call(self, *args, **kwargs)

                # Re-enable the overrides.
                global_disable_torch_function_override = \
                    old_global_disable_torch_function_override

                # after hooks
                if first_call:
                    output = parent_qstate.first_call_op_prepare_after_hook(
                        cur_module, output, args, qtensor_id,
                        OpQuantizeabilityType.QUANTIZEABLE)
                else:
                    output = parent_qstate.op_prepare_after_hook(
                        cur_module, output, args, global_op_idx)
                parent_qstate.mark_cur_op_complete(cur_module)

            elif hook_type is HookType.MODULE_IO_HOOKS:
                # TODO(future PR): add inputs io hook
                cur_qstate = cur_module._auto_quant_state
                cur_qstate.reset_to_new_call()

                # original forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                if first_call:
                    output = cur_qstate.first_call_outputs_prepare_hook(
                        output, qtensor_id)
                else:
                    output = cur_qstate.outputs_prepare_hook(output)

                cur_qstate.validate_is_at_last_seen_idx()

            elif hook_type is HookType.ARG_DEQUANTS:
                if first_call and parent_module is not None:
                    parent_qstate_fc = getattr(
                        parent_module, '_auto_quant_state', None)
                    if parent_qstate_fc:
                        args, kwargs = \
                            parent_qstate_fc.first_call_op_prepare_before_hook(
                                cur_module, args, kwargs, qtensor_id, fqn,
                                cur_module,
                                OpQuantizeabilityType.NOT_QUANTIZEABLE)

                output = orig_module_call(self, *args, **kwargs)
                # if this fp32 was inplace, make sure to set the output dtype
                # back to torch.float
                if hasattr(output, '_qtensor_info'):
                    del output._qtensor_info

                if first_call and parent_module is not None:
                    parent_qstate_fc = getattr(
                        parent_module, '_auto_quant_state', None)
                    if parent_qstate_fc:
                        output = \
                            parent_qstate_fc.first_call_op_prepare_after_hook(
                                cur_module, output, args, qtensor_id,
                                OpQuantizeabilityType.NOT_QUANTIZEABLE)

            else:
                output = orig_module_call(self, *args, **kwargs)

            if enable_logging:
                fqn = module_id_to_fqn.get(id(self), None)
                logger.debug(f" fqn:{fqn} _cl_: {type(self)} end")

            return output
        finally:
            module_stack.pop()
            cur_module = old_module

    torch.nn.Module.__call__ = _patched_module_call
    torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
    nonlocal first_call
    try:
        if first_call:
            # Create a list before iterating because we are adding new
            # named modules inside the loop.
            named_modules = list(self.named_modules())

            # Record module instances which are leaves or children of leaves
            leaves = set()
            for fqn, child in named_modules:
                if is_leaf(child, prepare_custom_config_dict):
                    for _, child_child in child.named_modules():
                        leaves.add(child_child)

            self._fqn_to_auto_quant_state_map = AutoQuantizationStateModuleDict()

            for fqn, v in named_modules:
                # fqn is the global FQN, i.e. 'foo.bar.baz'
                # v is the module instance
                #
                # we need to associate the global FQN with SeenOp
                # for modules, this is the module FQN
                # for functions, this is the parent module FQN
                module_id_to_fqn[id(v)] = fqn

                if v in leaves:
                    continue

                if v is self:
                    # for the top level module only, specify input
                    # and output dtypes
                    auto_quant_state = AutoQuantizationState(
                        qconfig_dict, fqn, input_dtypes, output_dtypes)
                else:
                    auto_quant_state = AutoQuantizationState(qconfig_dict, fqn)

                # The code below registers the auto_quant_state object
                # of the child in the module hierarchy of the parent,
                # and adds the auto_quant_state object to the child
                # with a raw __setattr__, without registering it in
                # the module hierarchy of the child.
                # This is solving the problem of both storing extra state
                # (observers) as well as not modifying the meaning of user
                # code in child modules which iterates over all module
                # children.
                #
                # This narrows down the issue of dynamically adding
                # children to only affect the top level module and not
                # the children.

                # On the parent, register this module in the FQN map
                fqn_to_use_for_key = get_fqn_valid_for_module_dict_key(fqn)
                self._fqn_to_auto_quant_state_map[fqn_to_use_for_key] = \
                    auto_quant_state
                # On the child, manually set the attribute without
                # going through the `torch.nn.Module.__setattr__`
                # function, to prevent this object from appearing in
                # the child's module hierarchy.
                object.__setattr__(v, '_auto_quant_state', auto_quant_state)

        global_op_idx[0] = 0

        output = super().__call__(*new_args, **new_kwargs)

        if first_call:
            for _, v in self.named_modules():
                if hasattr(v, '_auto_quant_state'):
                    v._auto_quant_state.match_fusion_patterns()
                    v._auto_quant_state.insert_observers(v)

        return output
    finally:
        torch.nn.Module.__call__ = orig_module_call
        torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
        first_call = False
def __call__(self, *args, **kwargs):
    new_args = map_aggregate(args, convert_to_interception_proxy)
    new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy)
    orig_module_call = torch.nn.Module.__call__
    orig_nn_sequential_forward = torch.nn.Sequential.forward

    def _patched_module_call(self, *args, **kwargs):
        if enable_logging:
            fqn = module_id_to_fqn.get(id(self), None)
            logger.debug(f" fqn:{fqn} _cl_: {type(self)} start")

        nonlocal cur_module
        old_module = cur_module
        cur_module = self
        try:
            parent_module = module_stack[-1] if len(module_stack) else None
            module_stack.append(self)
            fqn = module_id_to_fqn.get(id(self), None)

            hook_type = get_module_hook_type(parent_module, cur_module)

            if first_call and hook_type is not HookType.OP_HOOKS and \
                    parent_module is not None:
                parent_qstate_fc = getattr(
                    parent_module, '_auto_quant_state', None)
                if parent_qstate_fc:
                    parent_qstate_fc.add_seen_op_type_without_op_hooks(
                        type(cur_module))

            if hook_type is HookType.OP_HOOKS:
                parent_qstate: AutoQuantizationState = \
                    parent_module._auto_quant_state  # type: ignore[union-attr, assignment]
                # before hooks
                if not first_call:
                    parent_qstate.validate_cur_op(cur_module)

                # If we are in this hook, `cur_module` is a leaf module.
                # Therefore, we do not need to override any of its
                # children. Disabling the overrides for performance.
                nonlocal global_disable_torch_function_override
                old_global_disable_torch_function_override = \
                    global_disable_torch_function_override
                global_disable_torch_function_override = True

                # mypy ignore is used instead of assert because this
                # runs on every forward and assert has a performance cost
                args, kwargs = parent_qstate.op_prepare_before_hook(
                    cur_module, args, kwargs, first_call, qtensor_id, fqn,
                    cur_module)  # type: ignore[arg-type]

                # original forward
                output = orig_module_call(self, *args, **kwargs)

                # Re-enable the overrides.
                global_disable_torch_function_override = \
                    old_global_disable_torch_function_override

                # after hooks
                output = parent_qstate.op_prepare_after_hook(
                    cur_module, output, args, first_call, qtensor_id,
                    global_op_idx)
                parent_qstate.mark_cur_op_complete(cur_module)

            elif hook_type is HookType.MODULE_IO_HOOKS:
                # TODO(future PR): add inputs io hook
                cur_qstate = cur_module._auto_quant_state
                cur_qstate.reset_to_new_call()

                # original forward
                output = orig_module_call(self, *args, **kwargs)

                # after hooks
                output = cur_qstate.outputs_prepare_hook(
                    output, first_call, qtensor_id)
                cur_qstate.validate_is_at_last_seen_idx()

            elif hook_type is HookType.ARG_DEQUANTS:
                output = orig_module_call(self, *args, **kwargs)
                # if this fp32 was inplace, make sure to set the output dtype
                # back to torch.float
                if hasattr(output, '_qtensor_info'):
                    del output._qtensor_info

            else:
                output = orig_module_call(self, *args, **kwargs)

            if enable_logging:
                fqn = module_id_to_fqn.get(id(self), None)
                logger.debug(f" fqn:{fqn} _cl_: {type(self)} end")

            return output
        finally:
            module_stack.pop()
            cur_module = old_module

    torch.nn.Module.__call__ = _patched_module_call
    torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
    nonlocal first_call
    try:
        if first_call:
            # Create a list before iterating because we are adding new
            # named modules inside the loop.
            named_modules = list(self.named_modules())

            # Record module instances which are leaves or children of leaves
            leaves = set()
            for fqn, child in named_modules:
                if is_leaf(child, prepare_custom_config_dict):
                    for _, child_child in child.named_modules():
                        leaves.add(child_child)

            for fqn, v in named_modules:
                # fqn is the global FQN, i.e. 'foo.bar.baz'
                # v is the module instance
                #
                # we need to associate the global FQN with SeenOp
                # for modules, this is the module FQN
                # for functions, this is the parent module FQN
                module_id_to_fqn[id(v)] = fqn

                if v in leaves:
                    continue

                if v is self:
                    # for the top level module only, specify input
                    # and output dtypes
                    v._auto_quant_state = AutoQuantizationState(
                        qconfig_dict, fqn, input_dtypes, output_dtypes)
                else:
                    v._auto_quant_state = AutoQuantizationState(
                        qconfig_dict, fqn)

        global_op_idx[0] = 0

        output = super().__call__(*new_args, **new_kwargs)

        if first_call:
            for _, v in self.named_modules():
                if hasattr(v, '_auto_quant_state'):
                    v._auto_quant_state.insert_observers(v)

        return output
    finally:
        torch.nn.Module.__call__ = orig_module_call
        torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
        first_call = False
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively serializes a graph module (fx_module) to a dictionary which
    is later exported to JSON. It also adds all weights to the provided weights
    dictionary by qualified_name.

    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE
    {
        shape: [],
        stride: [],
        dtype: dtype,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {},
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: float,
        q_per_channel_scales: [],
        q_per_channel_zero_points: [],
        q_per_channel_axis: int,
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    submodules = dict(fx_module.named_modules())
    prefix = f"{name_prefix}." if name_prefix else ""

    def get_node_info(node):
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qparams["qscheme"])

            if tensor_meta.qparams["qscheme"] in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.qparams["scale"]
                node_rep["q_zero_point"] = tensor_meta.qparams["zero_point"]

        # Add all extra lowering_info that was provided in node.meta.
        lowering_info = node.meta.get("lowering_info")
        if lowering_info is not None:
            overlapping_keys = node_rep.keys() & lowering_info.keys()
            assert (
                len(overlapping_keys) == 0
            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
            node_rep.update(lowering_info)

        return node_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)

    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}

        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if (not (node.op == "call_module"
                 and isinstance(submodules[node.target], GraphModule))
                and node.op != "output"):
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if isinstance(submodules[node.target], GraphModule):
                serialized_module = serialize_module(
                    getattr(fx_module, node.target), weights, node.target)
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                qualname = node.target[len("parent."):]
                node.name = qualname
                node_rep["target"] = qualname
            else:
                qualname = prefix + node.target
            # Find the actual target parameter/buffer from the fx_module.
            submod_path, _, target_name = node.target.rpartition(".")
            submod: Optional[torch.nn.Module] = (
                fx_module.get_submodule(submod_path) if submod_path else fx_module)
            assert submod is not None, f"submod {submod_path} not found"
            target = getattr(submod, target_name, None)
            assert target is not None, f"{target_name} not an attr of {submod_path}"
            # Check that the target is a tensor, and that we haven't added it
            # already from a leaf module.
            if isinstance(target, torch.Tensor) and qualname not in weights:
                weight = serialize_weight(target, weights, qualname)
                _update_weight_fused_dtypes(weight, qualname, node)
                serialized_dict["weights"].update(weight)
                weights[qualname] = target
        elif node.op == "placeholder":
            ph_type = node.meta.get("ph_type", "")
            assert (
                ph_type == "" or ph_type == "input_ph" or ph_type == "output_ph"
            ), "When present, placeholder type must be 'input_ph' or 'output_ph'"
            if ph_type == "input_ph":
                node_rep["ph_type"] = "input_ph"
            elif ph_type == "output_ph":
                node_rep["ph_type"] = "output_ph"

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        def get_user_info(user_node: Argument) -> Any:
            return {"is_node": True, "name": str(user_node)}

        def get_arg_info(arg: Argument) -> Any:
            if isinstance(arg, torch.fx.Node):
                return {"is_node": True, "name": str(arg)}
            elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
                return str(arg)
            else:
                return arg

        def get_output_arg_info(arg: Node) -> Dict[str, Any]:
            node_rep: Dict[str, Any] = get_arg_info(arg)
            node_rep.update(get_node_info(arg))
            return node_rep

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )
            # If there are multiple outputs then node_rep["args"][0] will be a tuple
            # or list. In this case we want to unpack the tuple or list.
            if isinstance(node_rep["args"][0], (tuple, list)):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
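# Why the output node above is serialized with map_arg while every other node
# uses map_aggregate: map_arg only transforms fx.Node leaves (constants pass
# through untouched), whereas map_aggregate transforms every leaf. A small
# standalone illustration:
import torch
import torch.fx
from torch.fx.node import map_arg, map_aggregate

def f(x):
    return x + 1, 2

gm = torch.fx.symbolic_trace(f)
output_node = list(gm.graph.nodes)[-1]

print(map_arg(output_node.args, lambda n: n.name))        # (('add', 2),)
print(map_aggregate(output_node.args, lambda a: str(a)))  # (('add', '2'),)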