Example #1
# Imports shared by these excerpts; helpers such as insert_dequantize_node,
# _parent_name, maybe_get_observer_for_node, activation_is_statically_quantized
# and get_swapped_custom_module_class come from the surrounding
# torch.ao.quantization.fx module and its utils.
from typing import Any, Callable, Dict, Optional, Set
import warnings

import torch
from torch.fx import Graph, GraphModule, Node


def convert_standalone_module(node: Node, modules: Dict[str, torch.nn.Module],
                              model: torch.fx.GraphModule, is_reference: bool,
                              backend_config_dict: Optional[Dict[str, Any]]):
    """ Converts a observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    changing this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named-module dict of the original model (name -> module)
      - model: original model
      - is_reference: a flag from the parent, provided by the user, that decides whether
        we want to produce a reference model or a fbgemm/qnnpack model
      - backend_config_dict: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(
        node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_input_quantized_idxs\
        .tolist()  # type: ignore[operator]
    # remove the dequantize nodes for inputs
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_output_quantized_idxs \
        .tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config_dict
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module, backend_config_dict=backend_config_dict)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
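The `insert_dequantize_node` helper called above is not shown in this excerpt. Below is a minimal sketch of what such a helper plausibly does, assuming only the standard torch.fx graph-mutation API; this is an illustration, not the library's actual implementation:

from torch.fx import Graph, Node

def insert_dequantize_node(node: Node, graph: Graph):
    """Sketch: insert `node.dequantize()` right after `node` and rewire all
    existing users of `node` to consume the dequantized value instead."""
    with graph.inserting_after(node):
        dequantize_node = graph.call_method("dequantize", (node,))
        # iterate over a copy, since rewiring mutates node.users
        for user_node in dict(node.users):
            # skip the new dequantize node, which legitimately uses `node`
            if user_node is not dequantize_node:
                user_node.replace_input_with(node, dequantize_node)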
Example #2
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config_dict: Dict[str, Any]):
    convert = torch.ao.quantization._quantize_fx_do_not_use._convert_do_not_use  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_input_quantized_idxs\
        .tolist()  # type: ignore[operator]
    # remove the dequantize nodes for inputs
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_output_quantized_idxs \
        .tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config_dict to override backend_config_dict
    # for standalone module
    quantized_standalone_module = convert(
        observed_standalone_module,
        is_reference=True,
        backend_config_dict=backend_config_dict)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
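Both variants finish by swapping the converted module back into the model hierarchy via `_parent_name` and `setattr`. Here is a small self-contained illustration of that swap pattern, with a hypothetical stand-in for `_parent_name` (the real helper lives in the quantization utils):

import torch

def _parent_name(target: str):
    """Hypothetical stand-in: split a dotted module path into
    (parent_path, attribute_name); "" denotes the root module."""
    r = target.rsplit(".", 1)
    return ("", r[0]) if len(r) == 1 else (r[0], r[1])

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
modules = dict(model.named_modules())  # "" maps to the root module
parent_name, name = _parent_name("1")  # -> ("", "1")
# re-assigning the attribute on the parent swaps the submodule in place
setattr(modules[parent_name], name, torch.nn.Identity())
assert isinstance(model[1], torch.nn.Identity)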
Example #3
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
    """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
    we'll recursively remove the dequantize Node
    """
    if isinstance(arg, Node) and \
       arg.op == "call_method" and \
       arg.target == "dequantize":
        quantize_node = arg.args[0]
        # we only replace the specific use since dequantize could be used by other nodes
        # as well
        node.replace_input_with(arg, quantize_node)
    elif isinstance(arg, (list, tuple)):
        for arg_element in arg:
            maybe_recursive_remove_dequantize(arg_element, node, graph)
    elif isinstance(arg, dict):
        for arg_element in arg.values():
            maybe_recursive_remove_dequantize(arg_element, node, graph)
    else:
        warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
Example #4
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[Callable, Callable],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.
    For static quantization, we also remove the preceding `dequantize` node and
    attach the output observer to the module; later in the graph, that observer
    is converted to a single dequantize node rather than a quantize-dequantize
    pair. In the end we have a quantized custom module that has the same
    interface as a default quantized module in the nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named-module dict of the original model (name -> module)
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: a set we add the custom module node
        to when we find it is statically quantized. This is used later when converting
        observers to quant/dequant node pairs: if the observed node is a statically
        quantized custom module node, we convert the observer to a dequantize node,
        to keep the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with the reference model
        design as well, but there have been some discussions around the interface,
        so we can do it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        # remove the previous dequant node
        prev_node = node.args[0]
        # expecting the input node for a custom module node to be a Node
        assert isinstance(prev_node, Node), \
            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
        if prev_node.op == "call_method" and prev_node.target == "dequantize":
            # change the connection for custom module, we'll change the input
            # of custom module node to quantize node:
            # Before: quantize - dequantize - custom - module
            # After: quantize - custom - module
            #              \ - dequantize
            node.replace_input_with(prev_node, prev_node.args[0])

            # Remove the dequantize node if it doesn't have other users
            if len(prev_node.users) == 0:
                graph.erase_node(prev_node)

        # absorb the following observer into the module conversion
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
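The mapping and the `from_observed` classmethod together define the contract a custom module pair must satisfy: the code looks the observed class up in `custom_module_class_mapping` (a flat observed-to-quantized map, per this signature's `Dict[Callable, Callable]` annotation) and calls `from_observed` on the result. A minimal sketch of such a pair, with hypothetical class names:

import torch

class ObservedCustom(torch.nn.Module):
    def forward(self, x):
        return x

class QuantizedCustom(torch.nn.Module):
    @classmethod
    def from_observed(cls, observed: "ObservedCustom") -> "QuantizedCustom":
        # copy whatever state the quantized forward needs from the observed
        # module (weights, scale/zero_point derived from its observers, ...)
        return cls()

    def forward(self, x):
        return x

# flat mapping used to swap observed custom modules for quantized ones
custom_module_class_mapping = {ObservedCustom: QuantizedCustom}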