示例#1
0
def insert_dequantize_node(node: Node, graph: Graph):
    """Route every consumer of `node` through a new `dequantize` call.

    A `call_method` node targeting "dequantize" is created immediately after
    `node` in `graph`, taking `node` as its only argument.  Every node that
    was using `node` is then rewired to read from the new node instead.

    Args:
        node: the node whose output should be dequantized.
        graph: the graph that owns `node`; it is modified in place.
    """
    with graph.inserting_after(node):
        dequantize_node = graph.call_method("dequantize", (node, ))
        # Take a snapshot first: `node.users` is mutated while inputs are
        # rewired.  The dequantize node itself must keep `node` as its input.
        consumers = [
            consumer for consumer in node.users
            if consumer is not dequantize_node
        ]
        for consumer in consumers:
            consumer.replace_input_with(node, dequantize_node)
示例#2
0
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module, graph: Graph, node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Swap an observer call node for a quantize -> dequantize pair.

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        If every argument and every user of `node` has a None qconfig, or if
        `get_quantize_node_info` returns None for the observer module, the
        observer node is simply removed and its users read its first argument
        directly.

        Args:
            model: module handed to `create_getattr_from_value` when a
                scale / zero_point value is turned into a graph node.
            graph: graph owning `node`; modified in place.
            node: the observer call node; `node.target` must be a str key
                into `modules`.
            modules: mapping from node target to module.
            node_name_to_scope: forwarded to `get_module_path_and_prefix`.
            qconfig_map: forwarded to `get_module_path_and_prefix` and
                `has_none_qconfig`.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(
            node, node_name_to_scope, qconfig_map)
        observer = modules[node.target]
        quantize_info = get_quantize_node_info(observer)

        # The observer is only worth converting when at least one of its
        # producers or consumers has a non-None qconfig.
        neighbors = [*node.args, *node.users.keys()]
        neighbor_has_no_qconfig = [
            has_none_qconfig(neighbor, qconfig_map) for neighbor in neighbors
        ]
        skip_replacement = all(neighbor_has_no_qconfig)

        if skip_replacement or quantize_info is None:
            # Nothing to quantize (or no matching quantize op was found for
            # this observer): drop the observer and reconnect its users to
            # the observer's input.
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
            return

        # Otherwise build quantize(x, *qparams) -> dequantize() in front of
        # the observer and redirect all of the observer's users to it.
        node_type, quantize_op, qparams = quantize_info
        with graph.inserting_before(node):
            quantize_args = [node.args[0]]
            for qparam_name, qparam_value in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if qparam_name not in ['_scale_', '_zero_point_']:
                    # Anything other than scale / zero_point (like axis,
                    # dtype) is stored as a literal in the graph.
                    quantize_args.append(qparam_value)
                    continue
                # scale and zero_point go through create_getattr_from_value,
                # which yields a node that is passed to the quantize op.
                # TODO: maybe need more complex attr name here
                attr_name = module_path + prefix + qparam_name
                quantize_args.append(
                    create_getattr_from_value(
                        model, graph, attr_name, qparam_value))

            quantize_call = graph.create_node(
                node_type, quantize_op, tuple(quantize_args), {})
            dequantize_call = graph.call_method(
                "dequantize", args=(quantize_call, ))
            node.replace_all_uses_with(dequantize_call)
            graph.erase_node(node)
示例#3
0
    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node, modules: Dict[str,
                                                    torch.nn.Module]) -> None:
        """ Swap an observer call node for a quantize -> dequantize pair.

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        Behaviour depends on the observer module's `dtype`:
          * torch.float32: the observer node is removed and its users read
            its first argument directly.
          * torch.quint8 / torch.qint8 / torch.float16: the observer node is
            replaced by a quantize node followed by a dequantize call.
          * any other dtype: the graph is left untouched.

        Args:
            graph: graph owning `node`; modified in place.
            node: the observer call node; `node.target` must be a str key
                into `modules`.
            modules: mapping from node target to module; must also contain
                the key "" (used here as the root module).
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer = modules[node.target]
        root = modules[""]
        observed_dtype = observer.dtype

        if observed_dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
            return

        if observed_dtype not in [torch.quint8, torch.qint8, torch.float16]:
            # Unhandled dtype: leave the observer in place.
            return

        node_type, quantize_op, qparams = get_quantize_node_info(observer)
        # Build quantize(x, *qparams) -> dequantize() in front of the
        # observer and redirect all of the observer's users to it.
        with graph.inserting_before(node):
            quantize_args = [node.args[0]]
            for qparam_name, qparam_value in qparams.items():
                if qparam_name not in ['_scale_', '_zero_point_']:
                    # Anything other than scale / zero_point (like axis,
                    # dtype) is stored as a literal in the graph.
                    quantize_args.append(qparam_value)
                    continue
                # scale and zero_point go through create_getattr_from_value
                # on the root module; the resulting node feeds the quantize op.
                # TODO: maybe need more complex attr name here
                quantize_args.append(
                    create_getattr_from_value(
                        root, graph, qparam_name, qparam_value))

            quantize_call = graph.create_node(
                node_type, quantize_op, tuple(quantize_args), {})
            dequantize_call = graph.call_method(
                "dequantize", args=(quantize_call, ))
            node.replace_all_uses_with(dequantize_call)
            graph.erase_node(node)