Beispiel #1
    def fuse(self, model, inplace=False):
        """Fuse every node group of ``model`` that matches a default fusion pattern.

        Args:
            model: a GraphModule; its ``graph`` and ``named_modules()`` are read.
            inplace: when False (default) a deep copy of ``model`` is fused so
                the caller's object is left untouched.

        Returns:
            A new ``GraphModule`` built from ``model`` and the fused graph.
        """
        if not inplace:
            model = copy.deepcopy(model)
        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        fusion_patterns = get_default_fusion_patterns()
        # find fusion: node name -> (root node of its match, handler object)
        fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        # maps an original node name to its counterpart in self.fused_graph
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                # the root of a match is replaced by the handler's fused output
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                # nodes not part of any match are copied unchanged
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        model = GraphModule(input_root, self.fused_graph)
        return model
Beispiel #2
def insert_observer(
    node: Node,
    observed_op: Node,
    observer: ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    input_or_output: str,
) -> Node:
    """
    Attaches `observer` to `model`, and creates a node which calls
    `observer` on the output of `node`.

    Args:
        node: node whose output is observed; the new node is inserted right
            after it.
        observed_op: the op being observed; its name is looked up in
            `node_name_to_scope` to derive the observer attribute name.
        observer: observer module to attach.
        model: module that receives the observer as a new attribute.
        modules: name -> module mapping; updated with the new observer.
        graph: graph into which the observer call is inserted.
        node_name_to_scope: node name -> (scope prefix, type).
        input_or_output: string embedded in the observer attribute name.

    Returns:
        The newly created ``call_module`` node.
    """
    # respect device affinity: move the observer to the model's device, if any
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    # NOTE: We get the FQN of the module/op being observed here using the node_name_to_scope
    # Please don't change/update this behavior as it might impact how observer stats are transferred
    # from the train model to the inference model for some models.
    obs_name_prefix, _ = node_name_to_scope[observed_op.name]
    # fall back to the node's own name when the op has an empty scope
    obs_name_prefix = node.name if obs_name_prefix == '' else obs_name_prefix
    if is_equalization_observer(observer):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = obs_name_prefix + '_' + input_or_output + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    modules[observer_name] = observer
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', observer_name, (node,), {})
    return new_obs
Beispiel #3
def remove_qconfig_observer_fx(model):
    """Return a GraphModule equivalent to ``model`` without activation post processes.

    Every ``call_module`` node whose module is an activation post process is
    bypassed: its users read the node's first argument instead.

    Args:
        model: a GraphModule; its ``graph`` and ``named_modules()`` are read.

    Returns:
        A new ``GraphModule`` built from ``model`` and the filtered graph.
    """
    # remove activation post process
    act_post_process_removed_graph = Graph()
    env = {}  # type: Dict[str, Any]

    modules = dict(model.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in model.graph.nodes:
        if node.op == "output":
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg))
            # the output node is emitted above; it must not be copied again
            continue
        if node.op == "call_module" and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]
        else:
            env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg)

    model = GraphModule(model, act_post_process_removed_graph)
    return model
Beispiel #4
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    """Insert a ``torch.quantize_per_tensor`` call on ``prev_node_c`` into ``graph_c``.

    ``scale`` and ``zero_point`` are stored as new attributes on ``gm_b``
    (names derived from ``node_a.name``) and read via ``get_attr`` nodes.
    The output dtype is ``torch.quint8``.

    Returns:
        The new ``call_function`` node, named ``dtype_cast_name``.
    """
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        'get_attr', scale_node_name, (), {}, scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
Beispiel #5
def insert_observer(
    node: Node,
    observer: torch.quantization.ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `observer` to `model`, and creates a node which calls
    `observer` on the output of `node`.

    Args:
        node: node whose output is observed; its name prefixes the observer
            attribute name and the new node is inserted right after it.
        observer: observer module to attach.
        model: module that receives the observer as a new attribute.
        modules: name -> module mapping; updated with the new observer.
        graph: graph into which the observer call is inserted.

    Returns:
        The newly created ``call_module`` node.
    """
    # respect device affinity: move the observer to the model's device, if any
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    if is_equalization_observer(observer):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = node.name + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    modules[observer_name] = observer
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', observer_name, (node,), {})
    return new_obs
Beispiel #6
    def fuse(self, model, fuse_custom_config_dict=None):
        """Fuse node groups of ``model`` that match default or custom fusion patterns.

        Args:
            model: a GraphModule; its ``graph`` and ``named_modules()`` are read.
                It is used as the root of the result (no copy is made).
            fuse_custom_config_dict: optional dict; patterns found under
                ``"additional_quant_pattern"`` are added to (and override) the
                default fusion patterns.

        Returns:
            A new ``GraphModule`` built from ``model`` and the fused graph.
        """
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        # NOTE(review): other fuse() variants read "additional_fusion_pattern";
        # this one reads "additional_quant_pattern" — confirm which key callers pass.
        additional_fusion_patterns = fuse_custom_config_dict.get(
            "additional_quant_pattern", {})
        # copy so the shared default mapping is never mutated
        fusion_patterns = get_default_fusion_patterns().copy()
        fusion_patterns.update(additional_fusion_patterns)
        # find fusion: node name -> (root node of its match, handler object)
        fusion_pairs = self._find_matches(input_root, input_graph,
                                          fusion_patterns)
        self.fused_graph = Graph()
        # maps an original node name to its counterpart in self.fused_graph
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        model = GraphModule(input_root, self.fused_graph)
        return model
Beispiel #7
    def fuse(self, model: GraphModule,
             fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
        """Fuse node groups of ``model`` that match default or custom fusion patterns.

        Args:
            model: the GraphModule to fuse; used as the root of the result.
            fuse_custom_config_dict: optional dict; patterns under
                ``"additional_fusion_pattern"`` are combined with the defaults.

        Returns:
            A new ``GraphModule`` built from ``model`` and the fused graph.
        """
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        additional_fusion_patterns = \
            fuse_custom_config_dict.get("additional_fusion_pattern", {})
        fusion_patterns = get_combined_dict(
            get_default_fusion_patterns(), additional_fusion_patterns)
        # find fusion: node name -> (root node of its match, handler object)
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        # maps an original node name to its counterpart in self.fused_graph
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                assert obj is not None
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        model = GraphModule(input_root, self.fused_graph)
        return model
Beispiel #8
 def replace_observer_with_dequantize_node(node: Node, graph: Graph):
     """Replace the observer call `node` with a dequantize of its input.

     The observer's input (expected to be a call custom module node) takes over
     all of the observer's uses, the observer node is erased from `graph`, and a
     dequantize node is then inserted after the custom module node via
     `insert_dequantize_node`.
     """
     call_custom_module_node = node.args[0]
     assert isinstance(call_custom_module_node, Node), \
         f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
     # Bypass and erase the observer; otherwise it would remain in the graph
     # between the custom module and the inserted dequantize node.
     node.replace_all_uses_with(call_custom_module_node)
     graph.erase_node(node)
     insert_dequantize_node(call_custom_module_node, graph)
Beispiel #9
def convert_custom_module(node: Node, graph: Graph,
                          modules: Dict[str, torch.nn.Module],
                          custom_module_class_mapping: Dict[Callable, Callable],
                          statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed standalone module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with reference model design
        as well, but there has been some discussions around the interface, so we can do
        it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        # recorded so the output observer is later converted to a dequantize node
        statically_quantized_custom_module_nodes.add(node)
        # remove the previous dequant node
        prev_node = node.args[0]
        # expecting the input node for a custom module node to be a Node
        assert isinstance(prev_node, Node), \
            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
        if prev_node.op == "call_method" and prev_node.target == "dequantize":
            # The message must be one parenthesized expression: previously the
            # trailing fragments were separate no-op statements and were lost.
            assert len(prev_node.users) == 1, (
                "dequantize node before custom module is used "
                "multiple times, this is currently not supported yet, but it can be "
                "supported by duplicating the dequantize nodes in these cases")
            # feed the dequantize's input straight into the custom module
            prev_node.replace_all_uses_with(prev_node.args[0])
            graph.erase_node(prev_node)

        # absorb the following observer into the module conversion
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
Beispiel #10
def fold_weight(
        quantized: QuantizedGraphModule,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> QuantizedGraphModule:
    """
    Trace back from the weight node until we hit getattr, reconstruct the
    graph module with the traced nodes and run the graph module to pack the
    weight. then replace the original chain of ops with the packed weight.

    Args:
        quantized: graph module whose weight-prepack chains are folded.
        node_name_to_scope: node name -> (module path, type); the module path
            of the prepack node's first user prefixes the new attribute name.

    Returns:
        A new QuantizedGraphModule in which each prepack chain is replaced by
        a ``get_attr`` node reading the precomputed packed weight.
    """
    packed_weights = dict()
    # map from folded node name to the prepack node that consumes it
    folded_nodes = dict()
    # get packed weights
    for node in quantized.graph.nodes:
        if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
            nodes_to_fold = collect_producer_nodes(node)
            if nodes_to_fold is not None:
                for node_to_fold in nodes_to_fold:
                    folded_nodes[node_to_fold.name] = node

                prepacking_module = graph_module_from_producer_nodes(
                    quantized, nodes_to_fold)
                packed_weight = prepacking_module()
                packed_weights[node.name] = packed_weight

    # remove folded nodes and replace the prepacking node with getattr
    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    quantized_root = quantized
    quantized_graph = quantized.graph

    for node in quantized_graph.nodes:
        prepack_node = folded_nodes.get(node.name, None)
        if prepack_node is node:
            packed_weight = packed_weights[node.name]
            # add a prepacked attribute to root
            op_node = list(prepack_node.users)[0]
            module_path, _ = node_name_to_scope[op_node.name]
            get_new_packed_weight_name = \
                get_new_attr_name_with_prefix(module_path + '_packed_weight_')
            packed_weight_name = get_new_packed_weight_name(quantized_root)
            setattr(quantized_root, packed_weight_name, packed_weight)
            # replace prepack node with a getattr node
            env[node.name] = folded_graph.create_node(
                'get_attr', packed_weight_name, (), {})
        elif prepack_node is not None:
            # remove the folded node
            continue
        else:
            # copy other nodes
            env[node.name] = folded_graph.node_copy(node, load_arg)
    quantized = QuantizedGraphModule(quantized_root, folded_graph,
                                     quantized_root.preserved_attr_names)
    return quantized
Beispiel #11
def insert_dequantize_node(node: Node, graph: Graph):
    """ Inserts dequantize node for `node` in `graph`.

    A ``dequantize`` method call on `node` is created right after it, and
    every other user of `node` is rewired to read the dequantize node instead.
    """
    with graph.inserting_after(node):
        dequantize_node = graph.call_method("dequantize", (node, ))
        # iterate over a snapshot: replace_input_with mutates node.users
        for user_node in dict(node.users):
            # the new dequantize node is itself a user of `node`; keep it
            if user_node is not dequantize_node:
                user_node.replace_input_with(node, dequantize_node)
Beispiel #12
    def __init__(self):
        # Graph that this object builds up.
        self.graph = Graph()

        # Private counters for generating names; all start at zero.
        self._tensor_name_counter = self._dim_name_counter = self._shape_name_counter = 0

        # Private letter pools (a-z and A-Z) for generating names.
        self._lowercase, self._uppercase = (
            tuple(string.ascii_lowercase),
            tuple(string.ascii_uppercase),
        )
Beispiel #13
def legalize_graph(gm: GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation, which disturbs the topologically sorted
    order of its input GraphModule, so that this order is restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Raises:
        RuntimeError: if the remaining node dependencies form a cycle, so no
            topological order exists.
    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {
        node: node.all_input_nodes.copy()
        for node in gm.graph.nodes
    }

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = Graph()
    value_remap: Dict[Node, Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

    # Remove the copied over nodes from the adjacency list.
    for copied_node in value_remap:
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            if all(dep in value_remap for dep in deps):
                value_remap[node] = new_graph.node_copy(
                    node, lambda n: value_remap[n])
                # Must be recorded so the node is removed from `dependencies`
                # below; otherwise the loop never ends and nodes are re-copied.
                copied_this_round.append(node)

        if not copied_this_round:
            # No remaining node has all its inputs available: a cycle.
            raise RuntimeError(
                f"legalize_graph: cannot topologically sort {len(dependencies)} "
                "remaining node(s); the graph contains a cycle")

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
Beispiel #14
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Only the FP32 (node_a) / INT8 (node_c) combination is implemented; any
    other combination raises ``AssertionError``. When ``prev_node_c`` is a
    list, one cast node is created per element and the list is returned.
    """
    node_input_type_a = get_node_input_type(node_a, gm_a)
    node_input_type_c = get_node_input_type(node_c, gm_b)

    if node_input_type_a == NodeInputType.FP32 and node_input_type_c == NodeInputType.INT8:
        dtype_cast_op = torch.dequantize
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} to {node_input_type_a} needs to be implemented")

    def _make_cast(input_node):
        # name derived from node_name_prefix via get_new_attr_name_with_prefix
        cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        return graph_c.create_node(
            'call_function', dtype_cast_op, (input_node, ), {}, cast_name)

    if isinstance(prev_node_c, Node):
        return _make_cast(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_make_cast(prev_node_c_inner) for prev_node_c_inner in prev_node_c]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
Beispiel #15
class Fuser:
    """Rewrite an FX GraphModule by fusing node groups that match fusion patterns."""

    def fuse(self, model, inplace=False):
        """Fuse every node group of ``model`` that matches a fusion pattern.

        Args:
            model: a GraphModule; its ``graph`` and ``named_modules()`` are read.
            inplace: when False (default) a deep copy of ``model`` is fused.

        Returns:
            A new ``GraphModule`` built from ``model`` and the fused graph.
        """
        if not inplace:
            model = copy.deepcopy(model)
        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        fusion_patterns = get_fusion_patterns()
        # find fusion: node name -> (root node of its match, handler object)
        fusion_pairs = self._find_matches(input_root, input_graph,
                                          fusion_patterns)
        self.fused_graph = Graph()
        # maps an original node name to its counterpart in self.fused_graph
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        model = GraphModule(input_root, self.fused_graph)
        return model

    def _find_matches(self, root, graph, patterns):
        """Map each node covered by a matched pattern to its match.

        Nodes are visited in reverse graph order; for each unmatched node every
        pattern is tried, and the first match recorded for a node wins.

        Returns:
            dict: node name -> (root node of the match, ``value(self, root)``).
        """
        modules = dict(root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                # (op, *arg_patterns): match the op, then recurse into args
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                # the first pattern matches will take precedence
                if node.name not in match_map:
                    match_map[node.name] = match

        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map
Beispiel #16
    def fuse(
        self,
        model: GraphModule,
        fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None,
    ) -> GraphModule:
        """Fuse node groups of ``model`` that match fusion patterns.

        Args:
            model: the GraphModule to fuse; used as the root of the result.
            fuse_custom_config_dict: optional dict. ``"additional_fusion_pattern"``
                adds patterns (only when ``backend_config_dict`` is None) and
                ``"preserved_attributes"`` lists attributes kept on the result.
            backend_config_dict: when given, the pattern -> handler mapping and
                the fuser method mapping are derived from it instead of the
                defaults.

        Returns:
            A ``FusedGraphModule`` built from ``model`` and the fused graph.
        """
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        if backend_config_dict is None:
            additional_fusion_patterns = \
                fuse_custom_config_dict.get("additional_fusion_pattern", {})
            fusion_pattern_to_fuse_handler_cls = get_combined_dict(
                get_default_fusion_patterns(), additional_fusion_patterns)
            fuser_method_mapping = None
        else:
            fusion_pattern_to_fuse_handler_cls = get_fusion_pattern_to_fuse_handler_cls(backend_config_dict)
            fuser_method_mapping = get_fuser_method_mapping(backend_config_dict)
        # find fusion
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
        self.fused_graph = Graph()
        # maps an original node name to its counterpart in self.fused_graph
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            maybe_last_node, _pattern, matched_node_pattern, obj = \
                fusion_pairs.get(node.name, (None, None, None, None))
            if maybe_last_node is node:
                assert obj is not None
                # TODO: currently we hard code the root node, which only works for
                # a tuple of two nodes, we want to make this more general to
                # support more complex patterns
                root_node = matched_node_pattern[-1]  # type: ignore[index]
                env[node.name] = obj.fuse(
                    self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                    fuse_custom_config_dict, fuser_method_mapping)
            elif maybe_last_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
        model = FusedGraphModule(input_root, self.fused_graph, preserved_attributes)
        return model
Beispiel #17
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.

    Supported ops:
      - ``get_attr``: the attribute value is read from ``gm_a`` (tensors are
        detached) and stored on ``gm_b`` under a new ``<name>_shadow_copy_``
        attribute.
      - ``call_method`` with target ``dequantize`` or ``to``: the input is
        copied recursively, then the method call is recreated in ``graph_c``.

    Raises:
        AssertionError: for any other op or method target.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        arg_copy = _copy_node_from_a_to_c(
            get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b,
            graph_c)  # type: ignore[arg-type]
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        if node_a.target == 'dequantize':
            copy_args = (arg_copy, )
        else:  # to
            # the second input (e.g. the target dtype) is reused as-is
            copy_args = (arg_copy, get_normalized_nth_input(node_a, gm_a, 1))
        node_a_copy = graph_c.create_node(
            node_a.op, node_a.target, copy_args, {}, node_a_copy_name)
        return node_a_copy
    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")
Beispiel #18
class Fuser:
    """Rewrite an FX GraphModule by fusing node groups that match fusion patterns."""

    def fuse(self, model: GraphModule,
             fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
        """Fuse node groups of ``model`` that match default or custom patterns.

        Args:
            model: the GraphModule to fuse; used as the root of the result.
            fuse_custom_config_dict: optional dict; patterns under
                ``"additional_fusion_pattern"`` are combined with the defaults.

        Returns:
            A new ``GraphModule`` built from ``model`` and the fused graph.
        """
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        additional_fusion_patterns = \
            fuse_custom_config_dict.get("additional_fusion_pattern", {})
        fusion_patterns = get_combined_dict(
            get_default_fusion_patterns(), additional_fusion_patterns)
        # find fusion: node name -> (root node of its match, handler object)
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        # maps an original node name to its counterpart in self.fused_graph
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                assert obj is not None
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        model = GraphModule(input_root, self.fused_graph)
        return model

    def _find_matches(
            self, root: GraphModule, graph: Graph,
            patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, FuseHandler]]:
        """Map each node covered by a matched pattern to its match.

        Nodes are visited in reverse graph order; for each unmatched node every
        pattern is tried, and the first match recorded for a node wins.

        Returns:
            dict: node name -> (root node of the match, ``value(self, root)``).
        """
        modules = dict(root.named_modules())
        match_map : Dict[str, Tuple[Node, FuseHandler]] = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                # (op, *arg_patterns): match the op, then recurse into args
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                # the first pattern matches will take precedence
                if node.name not in match_map:
                    match_map[node.name] = match

        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map
Beispiel #19
def insert_observer(
        node: Node, observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Insert observer for node by modifying the observed_graph and
       attach observer module to the model
       Args:
         node: Node
         observer: observer/fake_quantize module instance
         model: module that receives the observer as a new attribute
         activation_post_process_map: node name -> observer; updated here
         env: node name -> node in observed_graph; updated so `node` maps to
             the new observer call
         observed_graph: graph receiving the observer call
         load_arg: maps an original node to its observed_graph counterpart
         observed_node_names_set: names of observed nodes; `node.name` is added
    """
    # respect device affinity when adding observers
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    prefix = node.name + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    # put observer instance activation_post_process map
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer
    # insert observer call
    env[node.name] = observed_graph.create_node(
        'call_module', observer_name, (load_arg(node),), {})
    # record that this node now has an observer attached
    observed_node_names_set.add(node.name)
Beispiel #20
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module, graph: Graph, node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        If every producer and consumer of the observer has a ``None`` qconfig,
        or no quantize op/info is found for the observer module, the observer
        is simply removed and its input takes over its uses.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(
            node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers to quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all(
            has_none_qconfig(n, qconfig_map)
            for n in list(node.args) + list(node.users.keys()))
        if skip_replacement or maybe_quantize_node_info is None:
            # didn't find corresponding quantize op and info for the observer_module
            # so we just remove the observer
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise, we can convert the observer module call to quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op,
                                                   tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize",
                                                     args=(quantized_node, ))
                # route the observer's users to the dequantized value and drop it
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)
Beispiel #21
    def _fold_weight(self, quantized):
        """Precompute weight-prepack chains of ``quantized`` and fold them away.

        For each ``call_function`` node whose target is in WEIGHT_PREPACK_OPS,
        its producer nodes are run as a standalone module to get the packed
        weight, which is stored as a new attribute on ``quantized``; the
        prepack node becomes a ``get_attr`` and its producer nodes are dropped.

        Returns:
            A new ``GraphModule`` with the folded graph.
        """
        packed_weights = dict()
        # map from folded node name to the prepack node that consumes it
        folded_nodes = dict()
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # NOTE(review): prefix restored from upstream; the original argument was lost
        get_new_packed_weight_name = get_new_attr_name_with_prefix(
            '_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_attr', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        quantized = GraphModule(quantized_root, folded_graph)
        return quantized
    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """Convert `node` into `quantized_graph` (reference path only).

        Returns ``NotImplemented`` when not all node args are tensors, and
        asserts that ``is_reference`` is True. If the qconfig's activation
        dtype is ``torch.float`` the node is copied with float inputs;
        otherwise its input is first loaded quantized to that dtype, the node
        is copied with float inputs, and the copy's output is quantized.
        """
        if not self.all_node_args_are_tensors:
            return NotImplemented
        assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
            'call_function are handled in DefaultNode'
        assert is_reference
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        additional_static_quant_mapping = convert_custom_config_dict.get(
            "static", {})

        dtypes = get_qconfig_dtypes(qconfig)
        # We can produce reference for a dtypes including
        # (torch.quint8, torch.qint8, torch.qint32, torch.float16)
        act_dtype = activation_dtype(qconfig)
        if act_dtype == torch.float:
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return op_out
        else:
            # NOTE(review): right-hand side restored from upstream PyTorch; the
            # original text was lost — confirm the helper name.
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            assert activation_post_process is not None
            # make sure the input is quantized to act_dtype
            load_arg(quantized={0: act_dtype})(node.args)
            args = load_arg(quantized=torch.float)(node.args)
            kwargs = load_arg(quantized=torch.float)(node.kwargs)
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return quantize_node(
                op_out, activation_post_process,
                node, modules, quantized_graph, node_name_to_scope, is_input=False)
Beispiel #23
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """Create a node in ``quantized_graph`` and copy metadata from ``old_node``.

    Args:
        quantized_graph: graph in which the new node is created.
        create_node_args: positional arguments forwarded verbatim to
            ``quantized_graph.create_node`` (op, target, args, kwargs, ...).
        old_node: node whose ``stack_trace`` is copied onto the new node.

    Returns:
        The newly created node.
    """
    new_node = quantized_graph.create_node(*create_node_args)
    # Keep the source location so diagnostics on the rewritten graph still
    # point back at the code that produced the original node.
    new_node.stack_trace = old_node.stack_trace
    return new_node
Beispiel #24
def create_getattr_from_value(module: GraphModule, graph: Graph, prefix: str, value: Any) -> Node:
    """Register ``value`` as a buffer on ``module`` and return a ``get_attr`` node for it.

    Args:
        module: module on which the new buffer is registered.
        graph: graph in which the ``get_attr`` node is created.
        prefix: prefix from which a fresh, unused attribute name is derived.
        value: a tensor, or anything ``torch.tensor`` accepts (e.g. a Python
            scalar).

    Returns:
        A ``get_attr`` node that reads the newly registered buffer.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    # ``torch.tensor(some_tensor)`` warns and recommends ``clone().detach()``.
    # Take that path explicitly for tensors. Either way the buffer is an
    # independent copy that carries no autograd history.
    if isinstance(value, torch.Tensor):
        buffer = value.clone().detach()
    else:
        buffer = torch.tensor(value)
    module.register_buffer(attr_name, buffer)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
Beispiel #25
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
) -> None:
    """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
    and updates them to match the new op code and target.

    A fresh graph is built. Every node with ``op == old_op`` and
    ``target == old_target`` is recreated with ``new_op`` / ``new_target``,
    keeping its args, kwargs and name. Every other node is copied unchanged.
    The new graph is then assigned to ``fx_module.graph``.

    Args:
        fx_module: module whose graph is replaced.
        old_op: op code to match (e.g. ``"call_function"``).
        old_target: target to match.
        new_op: op code for the replacement nodes.
        new_target: target for the replacement nodes.
    """
    new_graph = Graph()
    # Maps each node of the old graph to its counterpart in new_graph.
    val_map: Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(
                new_op, new_target, args, kwargs, node.name)
        else:
            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
    fx_module.graph = new_graph
Beispiel #26
def graph_module_from_producer_nodes(root, producer_nodes):
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function

    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph, in
        traced-back order: the node whose value is wanted comes first,
        followed by the nodes it depends on (ending at get_attr nodes).
        The list itself is not modified.

    Return:
      A graph module constructed from the producer nodes; calling it with no
      arguments returns the value of ``producer_nodes[0]``.
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # since we traced back from node to getattr, walk the nodes in reverse so
    # every node is copied after its inputs; use a reversed copy rather than
    # reversing the caller's list in place
    ordered_nodes = list(reversed(producer_nodes))
    graph = Graph()
    env: Dict[str, Node] = {}

    def load_arg(a):
        # Map every Node inside ``a`` to its copy in the new graph.
        return map_arg(a, lambda node: env[node.name])

    for producer_node in ordered_nodes:
        env[producer_node.name] = graph.node_copy(producer_node, load_arg)
    # Without an output node the generated forward() would return None; the
    # target node (last after reversal) is the value this module produces.
    graph.output(load_arg(ordered_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node, modules: Dict[str,
                                                    torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        Args:
            graph: graph containing ``node``; it is modified in place.
            node: a ``call_module`` node whose target names an observer module.
            modules: mapping from qualified name to module; must contain the
                observer (under ``node.target``) and the root module (under "").

        Observers with dtype ``torch.float32`` are removed: their users are
        rewired to the observer's input. Observers with dtype quint8, qint8
        or float16 are replaced by a quantize op followed by ``dequantize``.
        Observers with any other dtype are left untouched.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        elif observer_module.dtype in [
                torch.quint8, torch.qint8, torch.float16
        ]:
            node_type, quantize_op, qparams = get_quantize_node_info(
                observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op,
                                                   tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize",
                                                     args=(quantized_node, ))
            # Point the observer's users at the dequantized value, then drop
            # the now-unused observer call.
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
Beispiel #28
def create_getattr_from_value(module: torch.nn.Module, graph: Graph,
                              prefix: str, value: Any) -> Node:
    """Register ``value`` as a buffer on ``module`` and return a ``get_attr`` node for it.

    Args:
        module: module that receives the new buffer. The device returned by
            ``assert_and_get_unique_device(module)`` is used for non-tensor
            values.
        graph: graph in which the ``get_attr`` node is created.
        prefix: prefix from which a fresh, unused attribute name is derived.
        value: a tensor (copied via ``clone().detach()``) or anything
            ``torch.tensor`` accepts, e.g. a Python scalar.

    Returns:
        A ``get_attr`` node that reads the newly registered buffer.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    # Copy tensors so the buffer neither aliases the caller's tensor nor
    # carries its autograd history; build non-tensors on the module's device.
    new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
        else torch.tensor(value, device=device)
    module.register_buffer(attr_name, new_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
Beispiel #29
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    # Rebuild gm's graph with observer calls removed and loggers inserted.
    #
    # Args:
    #   gm: module whose graph is rewritten. It is also the root of the
    #     returned GraphModule.
    #   node_to_instrument_to_ref_node_name: nodes of gm.graph after which a
    #     logger is inserted, each mapped to an optional reference node name.
    #     Presumably this is the matching node in the other model (confirm).
    #   logger_cls: presumably forwarded to _insert_logger_after_node (the call
    #     is truncated below).
    #   model_name: presumably forwarded to the logger as well (confirm).
    #
    # Returns:
    #   A new GraphModule over gm with the rewritten graph.
    #
    # NOTE(review): this copy is garbled and does not parse as written. The
    # following pieces appear to be missing:
    #   - the docstring quotes around the text below;
    #   - the node-name subscripts (``env[]``);
    #   - presumably a ``continue`` after emitting the output node;
    #   - the argument of ``is_activation_post_process(``;
    #   - the remaining arguments of ``_insert_logger_after_node``;
    #   - an ``else:`` line.
    Takes the graph of gm, removes all observers, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new

    new_graph = Graph()
    # Maps node names of gm.graph to the corresponding nodes of new_graph.
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        # Replace every Node inside ``a`` with its new_graph counterpart.
        return map_arg(a, lambda node: env[])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))

        if node.op == 'call_module' and is_activation_post_process(
            # remove activation post process node
            env[] = env[node.args[0].name]

        elif node in node_to_instrument_to_ref_node_name:
            other_node_name = node_to_instrument_to_ref_node_name[node]
            # ensure env is populated with base node
            env[] = new_graph.node_copy(node, load_arg)
            # add the logger after the base node
            env[] = _insert_logger_after_node(env[], gm,

            # NOTE(review): an ``else:`` line appears to be missing above; this
            # plain copy is meant for all remaining nodes.
            env[] = new_graph.node_copy(node, load_arg)

    # Using gm as root keeps its submodules and attributes reachable from the
    # new graph (presumably including the inserted loggers).
    new_gm = GraphModule(gm, new_graph)
    return new_gm
Beispiel #30
    def convert(self, observed, inplace=False, debug=False):
        """Convert an observed GraphModule into a quantized GraphModule.

        Walks ``observed.graph`` and builds ``self.quantized_graph`` as follows:

        * Nodes matched by ``self._find_matches`` are converted through their
          handler's ``convert``.
        * ``call_module`` observer calls are replaced by
          ``torch.quantize_per_tensor``. Scale, zero point and dtype come from
          the observer's ``calculate_qparams()`` and ``dtype`` and are stored
          as attributes on the parent module.
        * Every other node is copied, with its inputs loaded as
          (dequantized) float values.

        Finally, attributes named ``activation_post_process_*`` are meant to
        be deleted from the root.

        Args:
            observed: the observed model; its ``root`` and ``graph`` are read.
            inplace: when False, ``observed.root`` is deep-copied before any
                attributes are added to or removed from it.
            debug: not referenced in the visible body.

        Returns:
            A new ``GraphModule`` over the (possibly copied) root and
            ``self.quantized_graph``.

        NOTE(review): this copy is heavily garbled. Node-name subexpressions
        (``env[]``, ``in env``), several ``else:`` lines and some call
        continuations are missing, so it does not parse as written.
        """
        assert self.activation_post_process_map is not None
        # move to cpu since we only have quantized cpu kernels
        # NOTE(review): nothing below actually moves the model to CPU; the
        # comment above may be stale.
        observed_root = observed.root
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)
        # Qualified-name -> module table, used to look up observer modules and
        # their parent modules below.
        self.modules = dict(observed_root.named_modules())

        # Pattern matches and per-node quant objects (argument list truncated).
        matches = self._find_matches(observed.graph, self.modules,
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        # env: original node -> new-graph node holding a float value.
        # quant_env: original node -> new-graph node holding a quantized value.
        # (Keys are presumably node names; the subscripts are garbled.)
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            # Return the float version of ``n``. If only a quantized version
            # exists, emit a ``dequantize`` call and cache it in env.
            if not in env:
                assert in quant_env, \
                    'trying to load float node but did not find node:' + + \
                    ' in quantized environment:' + str(quant_env)
                env[] = Proxy(quant_env[]).dequantize().node
            return env[]

        def load_quantized(n):
            # Return the quantized version of ``n``. If only a float version
            # exists, quantize it with the quant object recorded for ``n`` and
            # cache the result in quant_env.
            if not in quant_env:
                assert in env, \
                    'trying to load quantized node but did not find node:' + + \
                    ' in float environment:' + str(env)
                assert in quants, 'did not find quant object for node:' +
                quant = quants[][0]
                quant_env[] = quant.convert(self, env[])
            return quant_env[]

        def load_x(n):
            # Return whichever version of ``n`` exists, preferring quantized.
            assert in env or in quant_env, \
                'node ' + + ' does not exist in either of the environment'
            if in quant_env:
                return quant_env[]
                # NOTE(review): an ``else:`` line appears to be missing above.
                return env[]

        def load_arg(quantized):
            # Build a loader that maps an arg structure to new-graph nodes.
            # NOTE(review): the next four lines were a docstring; its quotes
            # are missing.
            if quantized is a list, then arg should be a list and the args with corresponding
            indexes will be quantized
            if quantized is a boolean, then all args will be quantized/not quantized
            if quantized is None, then we'll load the node as long as it exists
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg):
                if quantized is None:
                    return map_arg(arg, load_x)
                if isinstance(quantized, bool):
                    # NOTE(review): this map_arg call is missing its ``arg``
                    # argument.
                    return map_arg(
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg, (tuple, list)), arg
                    loaded_arg = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg):
                        if i in quantized:
                            loaded_arg.append(map_arg(a, load_quantized))
                            # NOTE(review): an ``else:`` line appears to be
                            # missing above.
                            loaded_arg.append(map_arg(a, load_non_quantized))
                    return type(arg)(loaded_arg)

            return load_arg_impl

        def is_quantized(node):
            # True/False for whether ``node`` (or every node in a list) is
            # available in quantized form.
            if isinstance(node, Node):
                assert in env or in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if in quant_env:
                    return True
                elif in env:
                    return False
            elif isinstance(node, list):
                # NOTE(review): ``map`` returns a one-shot iterator. ``all()``
                # consumes it, so ``any()`` below always sees an empty iterator.
                # Any non-all-quantized list then returns False, and the
                # "partially quantized" exception is unreachable. Likely this
                # should be ``list(map(...))``.
                quantized = map(is_quantized, node)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        # Main pass: rebuild every node of the observed graph in the new graph.
        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
      , (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                # Dynamic quantization keeps activations in float.
                if self.quant_type == QuantType.DYNAMIC:
                    quantized = False

                if quantized:
                    quant_env[] = result
                    # NOTE(review): an ``else:`` line appears to be missing above.
                    env[] = result
            elif root_node is not None:
                # NOTE(review): body missing. This is presumably ``continue``,
                # so nodes matched inside a pattern but not its root are skipped.

            # handle activation post process calls
            if node.op == 'call_module':
                    observer_module = self.modules[]
                    prev_node = node.args[0]
                    if in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[] = quant_env[]
                    # replace activation post process with quantization ops
                    parent_name = ''

                    # Per-tensor qparams from the observer, converted to Python
                    # scalars.
                    scale, zero_point = observer_module.calculate_qparams()
                    # TODO: per channel
                    scale = float(scale)
                    zero_point = int(zero_point)
                    dtype = observer_module.dtype
                    qparams = {
                        '_scale_': scale,
                        '_zero_point_': zero_point,
                        '_dtype_': dtype
                    i = 0

                    # noattr: True when no ``<key><i>`` attribute exists yet on
                    # module for any qparam key.
                    def noattr(module, qparams, i):
                        for name in qparams.keys():
                            if hasattr(module, name + str(i)):
                                return False
                        return True

                    # get_next_i: smallest suffix i for which all ``<key><i>``
                    # attribute names are free.
                    def get_next_i(module, qparams):
                        i = 0
                        while not noattr(module, qparams, i):
                            i += 1
                        return i

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    # Store each qparam as an attribute on the parent module and
                    # reference it from the graph (the append call is truncated).
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                                'get_param', qparam_full_path))
                    quant_env[] = self.quantized_graph.create_node(
                        'call_function', torch.quantize_per_tensor, inputs, {})
            # dequantize inputs for the node that are not quantized
            env[] = self.quantized_graph.node_copy(
                node, load_non_quantized)


        # Remove the observer submodules from the root now that they are no
        # longer referenced (the collection step's body is missing).
        to_be_removed = []
        for name, _ in observed_root.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)