def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    Map a ``Tensor.reshape`` node onto ``reshape`` (acc_ops).

    ``Tensor.reshape`` accepts its target shape either as varargs
    (``x.reshape(1, 2, 3)``) or as a single tuple/list (``x.reshape((1, 2, 3))``,
    ``x.reshape([1, 2, 3])``). In the second form the node's ``shape`` kwarg is a
    one-element sequence wrapping the real shape. This mapper unwraps it so the new
    node always receives the flat shape via ``acc_out_ty``.

    Args:
        node: The ``Tensor.reshape`` node; its kwargs must contain ``input`` and
            ``shape`` (``shape`` must be a ``Sequence``).
        _: The owning module (unused).

    Returns:
        A new ``call_function`` node targeting ``reshape``, inserted right before
        ``node`` with a copy of ``node.meta``. This function does not replace uses of
        ``node`` or erase it.
    """
    input_node = node.kwargs["input"]
    shape = node.kwargs["shape"]
    assert isinstance(shape, Sequence)
    # Unwrap the ``x.reshape((1, 2, 3))`` / ``x.reshape([1, 2, 3])`` form.
    # Check for emptiness first so an empty shape sequence does not raise
    # IndexError on ``shape[0]``.
    if len(shape) > 0 and isinstance(shape[0], (tuple, list)):  # type: ignore[index]
        shape = shape[0]  # type: ignore[index]

    with node.graph.inserting_before(node):
        new_node = node.graph.call_function(
            reshape,
            kwargs={
                "input": input_node,
                "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=shape),
            },
        )
        new_node.meta = node.meta.copy()
        return new_node
def packed_quantized_conv2d_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    """
    Map a quantized Conv2d module call to a ``quantized_conv2d`` node.

    The module's weight and bias are read once. They are registered as buffers on
    ``mod``, named after the submodule's qualified path with ``.`` replaced by
    ``_``, and passed to the new node through ``get_attr`` nodes. The new node also
    receives the module's conv hyperparameters (stride, padding, dilation, groups,
    padding_mode). The output quantization parameters (``scale``/``zero_point``)
    are passed via ``acc_out_ty``.

    Args:
        node: A node whose string ``target`` names the quantized conv submodule of
            ``mod``; its kwargs must contain ``input``.
        mod: The root module owning that submodule. Mutated: the weight buffer and,
            when present, the bias buffer are registered on it.

    Returns:
        The new ``quantized_conv2d`` node, inserted before ``node`` and sharing
        ``node.meta`` (same dict object, not a copy).
    """
    assert isinstance(node.target, str)
    conv_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # weight()/bias() are method calls on the module (presumably unpacking its
    # packed params). Call each once and reuse the result for both the buffer
    # and its tensor metadata.
    weight = conv_module.weight()
    bias = conv_module.bias()

    # Store weight and bias in the main module so get_attr nodes can reach them.
    mod.register_buffer(weight_name, weight)
    if bias is not None:
        mod.register_buffer(bias_name, bias)

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and (optional) bias.
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(weight)

        get_bias = None
        if bias is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(bias)

        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "stride": conv_module.stride,
            "padding": conv_module.padding,
            "dilation": conv_module.dilation,
            "groups": conv_module.groups,
            "padding_mode": conv_module.padding_mode,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=conv_module.scale, q_zero_point=conv_module.zero_point
            ),
        }
        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
def move_kwargs_to_acc_out_ty(
    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
    new_kwargs: Dict[str, Any],
):
    """
    Fold selected kwargs of ``new_kwargs`` into a single ``acc_out_ty`` kwarg.

    The NormalizationInfo is either passed directly or looked up in
    ``_normalization_dict`` by ``(node.op, node.target)``. If its
    ``kwargs_to_move_to_acc_out_ty`` is None, nothing happens. Otherwise each entry
    is an ``(orig_kwarg_name, tmd_field_name)`` pair or an
    ``(orig_kwarg_name, tmd_field_name, move_to_qparams)`` triple. Each named kwarg
    is removed from ``new_kwargs`` and stored under ``tmd_field_name``, either
    directly or inside a ``qparams`` dict when ``move_to_qparams`` is truthy. The
    collected fields are then passed to ``acc_utils.build_raw_tensor_meta``, and the
    result is stored as ``new_kwargs["acc_out_ty"]``.

    ``new_kwargs`` is mutated in place; nothing is returned.

    NOTE(review): this file also contains a later ``def move_kwargs_to_acc_out_ty``
    (2-tuple-only variant). At import time the later definition replaces this one.
    Confirm which is intended.
    """
    if isinstance(node_or_normalization_info, torch.fx.Node):
        lookup_key = (node_or_normalization_info.op, node_or_normalization_info.target)
        normalization_info = _normalization_dict.get(lookup_key)
    else:
        assert isinstance(node_or_normalization_info, NormalizationInfo)
        normalization_info = node_or_normalization_info
    assert normalization_info is not None

    if normalization_info.kwargs_to_move_to_acc_out_ty is None:
        return

    assert acc_utils.is_acc_op_with_kwarg(
        normalization_info.new_fn_target, "acc_out_ty"
    )

    # Only the fields named by the replacement tuples are populated; every other
    # tensor-metadata field is left for build_raw_tensor_meta to fill in.
    tmd_dict: Dict[str, Any] = {}
    qparams: Dict[str, Any] = {}
    for replacement in normalization_info.kwargs_to_move_to_acc_out_ty:
        if len(replacement) == 2:
            orig_kwarg_name, tmd_field_name = replacement  # type: ignore[misc]
            move_to_qparams = False
        else:
            assert len(replacement) == 3
            orig_kwarg_name, tmd_field_name, move_to_qparams = replacement  # type: ignore[misc]
        destination = qparams if move_to_qparams else tmd_dict
        # The kwarg now travels inside acc_out_ty, so take it out of new_kwargs.
        destination[tmd_field_name] = new_kwargs.pop(orig_kwarg_name)

    tmd_dict["qparams"] = qparams
    new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict)
def packed_quantized_linear_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    """
    Map a quantized Linear module call to a ``quantized_linear`` node.

    The module's weight and bias are read once. They are registered as buffers on
    ``mod``, named after the submodule's qualified path with ``.`` replaced by
    ``_``, and passed to the new node through ``get_attr`` nodes. The output
    quantization parameters (``scale``/``zero_point``) are passed via
    ``acc_out_ty``.

    Args:
        node: A node whose string ``target`` names the quantized linear submodule
            of ``mod``; its kwargs must contain ``input``.
        mod: The root module owning that submodule. Mutated: the weight buffer and,
            when present, the bias buffer are registered on it.

    Returns:
        The new ``quantized_linear`` node, inserted before ``node`` and sharing
        ``node.meta`` (same dict object, not a copy).
    """
    # Same precondition as the conv mapper: the target must be a submodule path.
    assert isinstance(node.target, str)
    linear_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # weight()/bias() are method calls on the module (presumably unpacking its
    # packed params). Call each once and reuse the result for both the buffer
    # and its tensor metadata.
    weight = linear_module.weight()
    bias = linear_module.bias()

    # Store weight and bias in the main module so get_attr nodes can reach them.
    mod.register_buffer(weight_name, weight)
    if bias is not None:
        mod.register_buffer(bias_name, bias)

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and (optional) bias.
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = extract_tensor_metadata(weight)

        get_bias = None
        if bias is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = extract_tensor_metadata(bias)

        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=linear_module.scale, q_zero_point=linear_module.zero_point
            ),
        }
        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module):
    """
    Map a ``Tensor.to`` node that only changes dtype onto ``to_dtype``.

    Preconditions (asserted):
      * ``dtype`` is present in the node's kwargs and is not None;
      * ``memory_format`` is absent/None or equal to ``torch.preserve_format``;
      * ``device`` is absent/None.

    Returns a new ``call_function`` node targeting ``to_dtype``, inserted before
    ``node``. It is created with ``name=node.name`` and shares ``node.meta``. The
    target dtype is carried in ``acc_out_ty``.
    """
    src_kwargs = node.kwargs
    dest_dtype = src_kwargs["dtype"]
    mem_format = src_kwargs.get("memory_format")
    device = src_kwargs.get("device")

    # Anything beyond a plain dtype conversion is not supported by this mapper.
    assert dest_dtype is not None
    assert mem_format is None or mem_format == torch.preserve_format
    assert device is None

    out_kwargs = {"input": src_kwargs["input"]}
    out_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(dtype=dest_dtype)

    graph = node.graph
    with graph.inserting_before(node):
        replacement = graph.create_node(
            "call_function", to_dtype, kwargs=out_kwargs, name=node.name
        )
        replacement.meta = node.meta
        return replacement
def add_relu_unfuse_mapper(
    node: torch.fx.Node, mod: torch.fx.GraphModule
) -> torch.fx.Node:
    """
    Split a fused add+relu node into ``quantized_add`` followed by ``relu``.

    The node's kwargs must contain ``input``, ``other``, ``scale`` and
    ``zero_point``. The add node receives the two operands plus an ``acc_out_ty``
    built from ``scale``/``zero_point`` and gets a copy of ``node.meta``. The relu
    node (``inplace=False``) consumes the add node and shares ``node.meta`` itself.
    Both are inserted before ``node``. The relu node is returned; ``mod`` is unused.
    """
    graph = node.graph
    fused_kwargs = node.kwargs
    with graph.inserting_before(node):
        add_kwargs = {"input": fused_kwargs["input"], "other": fused_kwargs["other"]}
        add_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(
            q_scale=fused_kwargs["scale"],
            q_zero_point=fused_kwargs["zero_point"],
        )
        add_node = graph.call_function(quantized_add, kwargs=add_kwargs)
        add_node.meta = node.meta.copy()

        relu_kwargs = {"input": add_node, "inplace": False}
        relu_node = graph.call_function(relu, kwargs=relu_kwargs)
        relu_node.meta = node.meta
        return relu_node
def move_kwargs_to_acc_out_ty(
    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
    new_kwargs: Dict[str, Any],
):
    """
    Fold selected kwargs of ``new_kwargs`` into a single ``acc_out_ty`` kwarg.

    The NormalizationInfo is either passed directly or looked up in
    ``_normalization_dict`` by ``(node.op, node.target)``. If its
    ``kwargs_to_move_to_acc_out_ty`` is None, nothing happens. Otherwise each
    ``(orig_kwarg_name, tmd_field_name)`` pair removes ``orig_kwarg_name`` from
    ``new_kwargs`` and records its value under ``tmd_field_name``. The collected
    fields are passed to ``acc_utils.build_raw_tensor_meta``, and the result is
    stored as ``new_kwargs["acc_out_ty"]``.

    ``new_kwargs`` is mutated in place; nothing is returned.

    Raises:
        AssertionError: if no NormalizationInfo is registered for the node, or
            the target op does not accept an ``acc_out_ty`` kwarg.
        KeyError: if a kwarg named for moving is missing from ``new_kwargs``.

    NOTE(review): this file also contains an earlier ``def move_kwargs_to_acc_out_ty``
    that additionally accepts 3-tuples routing values into ``qparams``. This later
    definition replaces it at import time. Confirm which is intended.
    """
    if isinstance(node_or_normalization_info, torch.fx.Node):
        node = node_or_normalization_info
        normalization_info = _normalization_dict.get((node.op, node.target))
    else:
        assert isinstance(node_or_normalization_info, NormalizationInfo)
        normalization_info = node_or_normalization_info
    # Without this check an unregistered node fails below with an opaque
    # AttributeError on None.
    assert normalization_info is not None, "no NormalizationInfo found"

    if normalization_info.kwargs_to_move_to_acc_out_ty is None:
        return

    assert acc_utils.is_acc_op_with_kwarg(normalization_info.new_fn_target, "acc_out_ty")

    # Build a dict of the TensorMetadata fields to set on acc_out_ty, removing
    # each kwarg from new_kwargs since it is now passed via acc_out_ty instead.
    tmd_dict: Dict[str, Any] = {}
    for (
        orig_kwarg_name,
        tmd_field_name,
    ) in normalization_info.kwargs_to_move_to_acc_out_ty:
        tmd_dict[tmd_field_name] = new_kwargs.pop(orig_kwarg_name)

    # Only the fields collected above are provided; build_raw_tensor_meta is
    # relied on to fill the remaining metadata fields (e.g. for quantization only
    # dtype/q_scale/q_zero_point are supplied here).
    new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict)