Example #1
def flatten(x: torch.fx.node.Argument) -> NodeList:
    """
    Stores nodes in x to a list and returns the list.
    """
    r: NodeList = []
    map_arg(x, r.append)
    return r
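
A minimal usage sketch (not part of the original listing; assumes `flatten` above is in scope, with `map_arg` imported from `torch.fx.node` and `NodeList = List[torch.fx.Node]`): `map_arg` applies the given function to every `torch.fx.Node` inside a possibly nested argument structure, so `flatten` collects the Nodes feeding another node:

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x + 1, y])

gm = torch.fx.symbolic_trace(M())
cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)
# cat_node.args is ([add, y],); flatten appends each Node it encounters
print(flatten(cat_node.args))  # -> [add, y]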
Example #2
def load_arg_impl(arg_or_args):
    if quantized is None:
        return map_arg(arg_or_args, load_x)
    if isinstance(quantized, bool):
        return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
    elif isinstance(quantized, (tuple, list)):
        assert isinstance(arg_or_args, (tuple, list)), arg_or_args
        loaded_args = []
        # for now, we only support quantizing positional arguments
        for i, a in enumerate(arg_or_args):
            if i in quantized:
                loaded_args.append(map_arg(a, load_quantized))
            else:
                loaded_args.append(map_arg(a, load_non_quantized))
        return type(arg_or_args)(loaded_args)
Example #3
    def _find_quants(self, quant_ctor):
        quants = {}

        def visit_arg(n):
            # note: we have to record quantization information even for nodes
            # whose information we might not use (e.g. because they are already
            # quantized), since each match has the option to return
            # NotImplemented (for instance, an __add__ whose data type is not
            # appropriate for quantization)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
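
A self-contained sketch of the same traversal (the module, the match set, and the stand-in for `quant_ctor` are hypothetical): for every matched node, `map_arg` walks its args and kwargs, and `visit_arg` records each input Node once:

import torch
import torch.fx
from torch.fx.node import map_arg

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

gm = torch.fx.symbolic_trace(M())
matches = {"relu"}  # pretend the relu node was matched
quants = {}

def visit_arg(n):
    if n.name not in quants:
        quants[n.name] = n  # stand-in for quant_ctor(self, n)

for node in gm.graph.nodes:
    if node.name in matches:
        map_arg(node.args, visit_arg)
        map_arg(node.kwargs, visit_arg)

print(list(quants))  # ['add'] -- the input feeding the matched relu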
Example #4
    def _find_quants(self, graph, matches):
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to record quantization information even for
                # nodes whose information we might not use (e.g. because they
                # are already quantized), since each match has the option to
                # return NotImplemented (for instance, an __add__ whose data
                # type is not appropriate for quantization)
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                            is_weight = True
                if self.quant_type != QuantType.DYNAMIC or is_weight:
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    map_arg(matched[0], visit(None, qconfig))
        return quants
Example #5
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
    and updates them to match the new op code and target"""
    new_graph = Graph()
    val_map: Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(new_op, new_target, args,
                                                  kwargs, node.name)
        else:
            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
    fx_module.graph = new_graph
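
A hedged usage sketch (the module and targets are illustrative) of the function above: replace every `torch.add` call with `torch.mul`, then recompile so `forward()` reflects the new graph:

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

gm = torch.fx.symbolic_trace(M())
replace_target_nodes_with(gm, "call_function", torch.add, "call_function", torch.mul)
gm.recompile()  # harmless if the graph setter already recompiled
print(gm.code)  # forward now calls torch.mul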
Example #6
    def quantize(self):
        self.quantized_graph = Graph()
        self.delegate = DelegateBase(self.quantized_graph)

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # quantizer chose not to quantize the node; take the
                    # entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)
Example #7
    def _find_quants(self, graph, matches):
        """
        Takes the nodes in the input graph and pending matches, and finds and
        returns the input and output nodes which need to be quantized.

        Inputs:
          - graph: an fx.Graph object
          - matches: output of self._find_matches function

        Outputs a map of
          node_name -> (QuantizeHandler instance (always DefaultQuant), qconfig)
        """
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to record quantization information even for
                # nodes whose information we might not use (e.g. because they
                # are already quantized), since each match has the option to
                # return NotImplemented (for instance, an __add__ whose data
                # type is not appropriate for quantization)
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and \
                        node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                            is_weight = True
                if qconfig is not None and \
                   (activation_is_statically_quantized(qconfig) or is_weight):
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)

            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    if isinstance(obj, StandaloneModuleQuantizeHandler):
                        # we don't insert observer for output of custom
                        # module
                        continue
                    map_arg(matched[0], visit(None, qconfig))
        return quants
Example #8
def load_arg(a):
    return map_arg(a, lambda node: env[node.name])
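
This one-liner is the core of the environment-remapping idiom used throughout these examples. A minimal self-contained sketch (the module is illustrative): copy a graph node by node, with `env` mapping old node names to new nodes and `load_arg` rewriting each node's (possibly nested) arguments to point into the copy:

import torch
import torch.fx
from torch.fx import Graph, GraphModule
from torch.fx.node import map_arg

gm = torch.fx.symbolic_trace(
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))

new_graph = Graph()
env = {}

def load_arg(a):
    return map_arg(a, lambda node: env[node.name])

for node in gm.graph.nodes:
    # node_copy clones node into new_graph; load_arg remaps its inputs
    env[node.name] = new_graph.node_copy(node, load_arg)

copied = GraphModule(gm, new_graph)  # behaviorally identical copy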
Example #9
    def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
        self.restore_state(model)
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        matches = self._find_matches(model.graph, self.modules, self.patterns)
        quants = self._find_quants(model.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + str(env) + \
                    ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = map(is_quantized, node)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception("partially quantized inputs in list not handled yet")

        for node in model.graph.nodes:
            root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(node, load_non_quantized)
                    quantized = False
                else:
                    result = obj.convert(self, node, load_arg)
                    # Need to get correct quantized/non-quantized state for the output of CopyNode
                    if isinstance(obj, CopyNode):
                        assert node.op in [
                            'call_module',
                            'call_function',
                            'call_method'], \
                            'CopyNode of type ' + node.op + ' is not handled'
                        quantized = is_quantized(node.args[0])
                    else:
                        quantized = True

                    # output of dynamic quantization is not quantized
                    if self.is_dynamic_quant:
                        quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith('activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if observer_module.dtype == torch.float16:
                        # activations are not quantized for fp16 dynamic
                        # quantization; copy the activation_post_process node
                        # here since we may need it when we insert the prepack
                        # op for the weight of linear. It will be removed
                        # later in a separate pass.
                        env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                        continue
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    root_module = self.modules['']
                    quant_env[node.name] = quantize_node(
                        root_module, self.quantized_graph,
                        load_non_quantized(node.args[0]), observer_module)
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
        self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        for node in self.quantized_graph.nodes:
            if node.op == 'call_module' and \
               node.target.split('.')[-1].startswith('activation_post_process_'):
                # remove activation post process
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg)
        act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_arg))

        to_be_removed = []
        for name, _ in model.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(model, n)
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
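
The final pass in `_convert` above (aliasing `activation_post_process_*` nodes to their input in `env` instead of copying them) is a general FX idiom for deleting pass-through nodes. A standalone sketch against the current FX API, where the output node is part of `graph.nodes` (the module and the choice of `torch.nn.Identity` as the node to elide are illustrative):

import torch
import torch.fx
from torch.fx import Graph, GraphModule
from torch.fx.node import map_arg

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.identity = torch.nn.Identity()

    def forward(self, x):
        return self.identity(x) + 1

gm = torch.fx.symbolic_trace(M())
modules = dict(gm.named_modules())
new_graph = Graph()
env = {}

def load_arg(a):
    return map_arg(a, lambda node: env[node.name])

for node in gm.graph.nodes:
    if node.op == 'call_module' and isinstance(modules[node.target], torch.nn.Identity):
        # alias instead of copying: users of this node get rewired to its input
        env[node.name] = env[node.args[0].name]
    else:
        env[node.name] = new_graph.node_copy(node, load_arg)

stripped = GraphModule(gm, new_graph)
print(stripped.code)  # the identity call is gone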
Example #10
def split_by_tags(gm: torch.fx.GraphModule,
                  tags: List[str]) -> torch.fx.GraphModule:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, with tags = ["a", "b", "c"], the function will create
    the initial submodules in the order "a_0", "b_1", "c_2".

    To set a tag:
    gm.graph.nodes[idx].tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed in their
    own submodule. For placeholder, output and get_attr nodes, the tag is ignored: placeholder
    and output nodes are created when needed, while get_attr nodes get copied to the
    submodules where they are used.

    Given the following module def:

    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:

    ro_0:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main_1:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    main_0:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1
    """
    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Stores nodes in x to a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in main graph.
    used_in_main: NodeSet = set()

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag, we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name,
                                                      type_expr=node.type)
            continue

        # get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly into the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes whose op is call_module,
        # call_function or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order at least as high as its upstream components.
        assert comp.order >= mx

        # Map an input of `node` to a node in the component's graph.
        def remap_func(x):
            # If input is a get_attr node, copy it to current component's graph.
            # Returns the get_attr node in current component's graph.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type)
                return comp.getattr_maps[x]

            # If input is not a placeholder, it should have been put into a component
            # already. If it's in the current component then we return the corresponding
            # node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If input is a placeholder or it's in another component, we make it
            # a placeholder in the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                comp.input_placeholders.append(
                    comp.graph.placeholder(x.name, type_expr=x.type))
                used_in_main.add(x)

            return comp.input_placeholders[next(
                i for i, y in enumerate(comp.orig_inputs) if x is y)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        used_in_main.add(x)

    # If a node is used in main graph then we mark it as an output in the component
    # it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a graphmodule for each component.
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        # Take care of the args of the FX output node. If there's a single
        # output then the output node args look like (output_single,); if
        # there are multiple outputs then the output node args look like
        # ((output_0, output_1, ...),).
        comp.graph.output(outs[0] if len(outs) == 1 else outs)

        # Loop through all module calls (call_module) and param fetches (get_attr)
        # in this component, creating HolderModules as necessary to match the path.
        # e.g. if the original module has a get_attr node that fetches "conv.weight",
        # we create a HolderModule as root -> add a HolderModule named "conv" ->
        # make "weight" an attribute of the "conv" HolderModule, pointing to
        # conv.weight in the original module.
        root = HolderModule({})
        for n in comp.graph.nodes:
            if n.op not in ("call_module", "get_attr"):
                continue

            target = n.target
            assert isinstance(target, str)
            target_name_parts = target.split(".")
            curr = root
            orig_gm = gm

            for name in target_name_parts[:-1]:
                if not hasattr(curr, name):
                    curr.add_module(name, HolderModule({}))

                curr = getattr(curr, name)
                orig_gm = getattr(orig_gm, name)

            leaf_node_name = target_name_parts[-1]
            leaf_node = getattr(orig_gm, leaf_node_name)

            # Relies on custom __setattr__ magic.
            setattr(curr, leaf_node_name, leaf_node)

        comp.gm = torch.fx.GraphModule(root, comp.graph)

        # Create a call_module node in main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(
                    main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})

    return torch.fx.GraphModule(main_root, main_g)
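
A hedged usage sketch (the module and tag assignment are illustrative; `split_by_tags` and the `Component`/`HolderModule` helpers it relies on live in `torch.fx.passes.split_utils`): tag every callable node, then split:

import torch
import torch.fx

class SimpleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x, y):
        return self.linear1(x) + self.linear2(y)

gm = torch.fx.symbolic_trace(SimpleModule())
for node in gm.graph.nodes:
    if node.op not in ("placeholder", "output", "get_attr"):
        # every callable node must carry a tag before splitting
        node.tag = "a" if node.target == "linear1" else "b"

split = split_by_tags(gm, ["a", "b"])
print(split.graph)  # the main graph calls one submodule per tag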