Example #1
def custom_tensor_reshape_mapper(node: torch.fx.Node,
                                 _: nn.Module) -> torch.fx.Node:
    """
    For a Tensor.reshape node, args could be (input, 1, 2, 3) or (input, (1, 2, 3)).
    Here we do some special handling of the `shape` arg in order to map it to
    acc_ops.reshape. This also handles the case where `shape` is a list instead
    of a tuple.
    """
    input_node = node.kwargs["input"]
    shape = node.kwargs["shape"]

    assert isinstance(shape, Sequence)
    if isinstance(shape[0], (tuple, list)):  # type: ignore[index]
        shape = shape[0]  # type: ignore[index]

    with node.graph.inserting_before(node):
        new_node = node.graph.call_function(
            reshape,
            kwargs={
                "input": input_node,
                "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=shape),
            },
        )
        new_node.meta = node.meta.copy()
        return new_node
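To see why the mapper must normalize the shape argument, note that under torch.fx tracing the two call styles produce different args on the call_method node. A minimal, runnable sketch (not part of the original source):

import torch
import torch.fx

class VarArgs(torch.nn.Module):
    def forward(self, x):
        return x.reshape(2, 3)

class TupleArg(torch.nn.Module):
    def forward(self, x):
        return x.reshape((2, 3))

for m in (VarArgs(), TupleArg()):
    gm = torch.fx.symbolic_trace(m)
    node = next(n for n in gm.graph.nodes if n.op == "call_method")
    print(node.target, node.args)  # reshape (x, 2, 3) vs. reshape (x, (2, 3))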
Example #2
def packed_quantized_conv2d_mapper(node: torch.fx.Node,
                                   mod: nn.Module) -> torch.fx.Node:
    """
    Mapping from quantized Conv2d module to acc_op.quantized_conv2d. We unpack all
    the parameters in this mapper and pass them directly to the conv2d node.
    """
    assert isinstance(node.target, str)
    conv_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, conv_module.weight())
    if conv_module.bias() is not None:
        mod.register_buffer(bias_name, conv_module.bias())

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(
            conv_module.weight())

        get_bias = None
        if conv_module.bias() is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(
                conv_module.bias())

        # Create kwargs for acc_op.quantized_conv2d
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "stride": conv_module.stride,
            "padding": conv_module.padding,
            "dilation": conv_module.dilation,
            "groups": conv_module.groups,
            "padding_mode": conv_module.padding_mode,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=conv_module.scale,
                q_zero_point=conv_module.zero_point),
        }

        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
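For context on what the mapper unpacks: on a packed quantized module the weight and bias are exposed through methods, while scale and zero_point are plain attributes. A hedged sketch (the class is referenced as torch.nn.quantized.Conv2d here; newer PyTorch releases expose the same module under torch.ao.nn.quantized):

import torch

# weight()/bias() are methods on the packed module; scale/zero_point describe
# the output quantization parameters.
qconv = torch.nn.quantized.Conv2d(3, 8, kernel_size=3)
print(qconv.weight().shape)           # quantized weight, e.g. torch.Size([8, 3, 3, 3])
print(qconv.bias())                   # float bias tensor, or None
print(qconv.scale, qconv.zero_point)  # defaults: 1.0 and 0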
Example #3
def move_kwargs_to_acc_out_ty(
    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
    new_kwargs: Dict[str, Any],
):
    """
    Given `node_or_normalization_info`, which is either the NormalizationInfo for a
    node or a node to fetch NormalizationInfo for, check whether
    kwargs_to_move_to_acc_out_ty exists in the NormalizationInfo, and if so move
    those kwargs into acc_out_ty.
    """

    if isinstance(node_or_normalization_info, torch.fx.Node):
        node = node_or_normalization_info
        normalization_info = _normalization_dict.get((node.op, node.target))
    else:
        assert isinstance(node_or_normalization_info, NormalizationInfo)
        normalization_info = node_or_normalization_info

    assert normalization_info is not None
    if normalization_info.kwargs_to_move_to_acc_out_ty is None:
        return

    assert acc_utils.is_acc_op_with_kwarg(
        normalization_info.new_fn_target, "acc_out_ty"
    )

    # Build a dict representing the new TensorMetadata to use for acc_out_ty,
    # and then remove the kwarg from the new_kwargs since it's passed in via
    # acc_out_ty instead.
    tmd_dict: Dict[str, Any] = {}
    qparams: Dict[str, Any] = {}

    for kwarg_replacement_tuple in normalization_info.kwargs_to_move_to_acc_out_ty:
        if len(kwarg_replacement_tuple) == 2:
            orig_kwarg_name, tmd_field_name, move_to_qparams = *kwarg_replacement_tuple, False  # type: ignore[misc]
        else:
            assert len(kwarg_replacement_tuple) == 3
            orig_kwarg_name, tmd_field_name, move_to_qparams = kwarg_replacement_tuple  # type: ignore[misc]
        if move_to_qparams:
            qparams[tmd_field_name] = new_kwargs[orig_kwarg_name]
        else:
            tmd_dict[tmd_field_name] = new_kwargs[orig_kwarg_name]
        del new_kwargs[orig_kwarg_name]

    tmd_dict["qparams"] = qparams
    # Note: allow_partial_spec here because we are only using the tensor metadata tuple
    # here to pass specific values into the function. For example, for quantization we
    # only need to provide qparams dictionary, but is_quantized is
    # not passed in.
    new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict)
Example #4
File: acc_ops.py Project: vors/pytorch
def packed_quantized_linear_mapper(node: torch.fx.Node,
                                   mod: nn.Module) -> torch.fx.Node:
    """
    Mapping from quantized Linear module to acc_op.quantized_linear. We unpack weight
    and bias in this mapper and pass them directly to the linear node.
    """
    assert isinstance(node.target, str)
    linear_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, linear_module.weight())
    if linear_module.bias() is not None:
        mod.register_buffer(bias_name, linear_module.bias())

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = extract_tensor_metadata(
            linear_module.weight())

        get_bias = None
        if linear_module.bias() is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = extract_tensor_metadata(
                linear_module.bias())

        # Create kwargs for acc_op.quantized_linear
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=linear_module.scale,
                q_zero_point=linear_module.zero_point),
        }

        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
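Both packed-module mappers rely on the same contract: the tensor must be registered on the owning module under exactly the name the get_attr node references, or torch.fx will warn about a dangling reference. A minimal, runnable sketch on a toy module (the buffer name is hypothetical, not from the source):

import torch
import torch.fx

gm = torch.fx.symbolic_trace(torch.nn.ReLU())
gm.register_buffer("relu_weight", torch.randn(4, 4))  # hypothetical buffer

first = next(iter(gm.graph.nodes))
with gm.graph.inserting_after(first):
    # get_attr resolves "relu_weight" against gm at runtime; registering the
    # buffer first keeps the reference valid.
    get_w = gm.graph.get_attr("relu_weight")
gm.graph.lint()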
Example #5
File: acc_ops.py Project: eqy/pytorch
def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module):
    dest_dtype = node.kwargs["dtype"]
    mem_format = node.kwargs.get("memory_format")
    device = node.kwargs.get("device")
    assert dest_dtype is not None
    assert mem_format is None or mem_format == torch.preserve_format
    assert device is None

    new_kwargs = {
        "input": node.kwargs["input"],
        "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dest_dtype),
    }

    with node.graph.inserting_before(node):
        new_node = node.graph.create_node(
            "call_function", to_dtype, kwargs=new_kwargs, name=node.name
        )
        new_node.meta = node.meta
        return new_node
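The assertions above restrict this mapper to dtype-only conversions: device moves are rejected, and the only accepted memory_format is torch.preserve_format. A sketch of the Tensor.to behavior being mapped (not the acc_ops implementation itself):

import torch

x = torch.randn(2, 2)
# The only form this mapper accepts: a dtype change, optionally with the
# default-preserving memory format.
y = x.to(dtype=torch.float16, memory_format=torch.preserve_format)
assert y.dtype == torch.float16 and y.device == x.device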
Example #6
File: acc_ops.py Project: eqy/pytorch
def add_relu_unfuse_mapper(
    node: torch.fx.Node, mod: torch.fx.GraphModule
) -> torch.fx.Node:
    with node.graph.inserting_before(node):
        add_kwargs = {
            "input": node.kwargs["input"],
            "other": node.kwargs["other"],
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=node.kwargs["scale"],
                q_zero_point=node.kwargs["zero_point"],
            ),
        }
        add_node = node.graph.call_function(quantized_add, kwargs=add_kwargs)
        add_node.meta = node.meta.copy()

        relu_node = node.graph.call_function(
            relu, kwargs={"input": add_node, "inplace": False}
        )
        relu_node.meta = node.meta
        return relu_node
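After a mapper like this returns, the surrounding normalizer presumably rewires uses of the original fused node to the returned node (an assumption about the framework, sketched below with plain torch.fx primitives on a toy graph). Note the classic gotcha: replace_all_uses_with also rewrites the argument of the newly inserted user, which must be restored to avoid a cycle.

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

gm = torch.fx.symbolic_trace(M())
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

with gm.graph.inserting_after(add_node):
    relu_node = gm.graph.call_function(torch.relu, args=(add_node,))

add_node.replace_all_uses_with(relu_node)  # also rewrites relu_node's own arg...
relu_node.args = (add_node,)               # ...so restore it
gm.recompile()
print(gm.code)  # forward now returns relu(x + y)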
Example #7
def move_kwargs_to_acc_out_ty(
    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
    new_kwargs: Dict[str, Any],
):
    """
    Given `node_or_normalization_info`, which is either the NormalizationInfo for a
    node or a node to fetch NormalizationInfo for, check whether
    kwargs_to_move_to_acc_out_ty exists in the NormalizationInfo, and if so move
    those kwargs into acc_out_ty.
    """
    if isinstance(node_or_normalization_info, torch.fx.Node):
        node = node_or_normalization_info
        normalization_info = _normalization_dict.get((node.op, node.target))
    else:
        normalization_info = node_or_normalization_info

    assert normalization_info is not None
    if normalization_info.kwargs_to_move_to_acc_out_ty is None:
        return

    assert acc_utils.is_acc_op_with_kwarg(normalization_info.new_fn_target,
                                          "acc_out_ty")

    # Build a dict representing the new TensorMetadata to use for acc_out_ty,
    # and then remove the kwarg from the new_kwargs since it's passed in via
    # acc_out_ty instead.
    tmd_dict: Dict[str, Any] = {}
    for (
            orig_kwarg_name,
            tmd_field_name,
    ) in normalization_info.kwargs_to_move_to_acc_out_ty:
        tmd_dict[tmd_field_name] = new_kwargs[orig_kwarg_name]
        del new_kwargs[orig_kwarg_name]
    # Note: allow_partial_spec here because we are only using the tensor metadata tuple
    # here to pass specific values into the function. For example, for quantization we
    # only need to provide dtype/q_scale/q_zero_point, but is_quantized and qscheme are
    # not passed in.
    new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict)