Пример #1
0
class Quantizer:
    """Rewrite the graph of a traced module into a quantized graph.

    Usage, as far as this class shows: construct it with a module that exposes
    ``graph``; call :meth:`observe` with example inputs so that every matched
    pattern object and every per-input quant object can record what it needs;
    then call :meth:`quantize` to emit a new ``GraphModule``.
    """

    def __init__(self,
                 mod,
                 patterns=DEFAULT_QUANTIZATION_PATTERNS,
                 quant_ctor=DefaultQuant):
        """Match patterns in ``mod.graph`` and create the quant objects.

        Args:
            mod: module to quantize; it must provide ``graph``,
                ``state_dict()`` and ``named_modules()``.
            patterns: mapping ``pattern -> factory``; each factory is called
                as ``factory(self, node)`` for the root node of a match.
            quant_ctor: callable invoked as ``quant_ctor(self, node)`` for
                every node that feeds a matched node.
        """
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so
        # initialize a quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        """Interpret the graph on ``args`` and let observers see each value.

        Args:
            args: iterable of inputs, consumed in order by the graph's
                ``placeholder`` nodes.

        Returns:
            The value of ``self.graph.result`` looked up in the environment
            built while interpreting.
        """
        # most of this function is just an interpreter for the graph
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to just be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people copy/paste
        # then edit.
        args_iter = iter(args)
        env = {}  # node name -> concrete value computed for that node

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # NOTE(review): a node whose op is none of the kinds handled below
        # keeps `result` from the previous iteration (or leaves it unbound on
        # the first one) -- confirm the graph never contains other op kinds.
        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args),
                                     **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # the first positional argument is the receiver of the method
                self_obj, *method_args = load_arg(node.args)
                method_kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args,
                                                        **method_kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args),
                                                   **load_arg(node.kwargs))

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            # only the root of a match is reported to its pattern object
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        """Emit the quantized graph and wrap it in a ``GraphModule``.

        The new graph is also stored on ``self.quantized_graph``.

        Returns:
            ``GraphModule(self.root, self.quantized_graph)``.
        """
        self.quantized_graph = Graph()

        env = {}        # node name -> non-quantized node in the new graph
        quant_env = {}  # node name -> quantized node in the new graph

        def load_arg(n, quantized):
            # Fetch the requested flavour of `n`, converting lazily (and
            # caching the conversion) when only the other flavour exists.
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy `node` unquantized. Interior nodes of a match are skipped
            # by the main loop below, so they exist in neither environment
            # yet and must be emitted here on demand.
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the pattern object chose not to quantize this node:
                    # take the entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # else: interior node of a match; it is handled via its root

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        """Assign graph nodes to the patterns that match them.

        Nodes are visited in reverse graph order, and a node already claimed
        by an earlier match is not tested again.

        Args:
            patterns: mapping ``pattern -> factory``. A tuple pattern is read
                as ``(head, *arg_patterns)``, with the argument patterns
                paired positionally with ``node.args``.

        Returns:
            Dict mapping node name to ``(root_node, factory(self, root_node))``.
        """
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            # Record `match` for `node` and, for tuple patterns, for every
            # argument node covered by a sub-pattern.
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        """Create one quant object per node that feeds a matched node.

        Args:
            quant_ctor: callable invoked as ``quant_ctor(self, node)``.

        Returns:
            Dict mapping node name to the quant object created for it.
        """
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
Пример #2
0
    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Build a new torch.fx.GraphModule from the input graph with common
        subexpressions eliminated (CSE), and report whether anything changed.

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b+c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
        """
        def get_aten_target(node):
            # targets that expose an `overloadpacket` are identified by it,
            # everything else by the target itself
            return getattr(node.target, 'overloadpacket', node.target)

        # old-graph node -> node standing in for it in the new graph
        env: Dict[Node, Node] = {}
        # (target, hash of remapped args) -> first new-graph node with that key
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}
        # same key -> full token, used to confirm that a hash hit is a real match
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}

        def remap_flat(tree):
            # Flatten `tree` and swap every already-copied old node for its
            # counterpart in the new graph; the spec can be used to
            # reconstruct the nested lists/dictionaries.
            leaves, spec = tree_flatten(tree)
            remapped = [
                env[leaf] if isinstance(leaf, Node) and leaf in env else leaf
                for leaf in leaves
            ]
            return tuple(remapped), spec

        # these op kinds are always carried over to the new graph unchanged
        passthrough_ops = ('placeholder', 'output', 'get_attr')

        modified = False
        new_graph = Graph()
        for n in graph_module.graph.nodes:
            # banned ops (e.g. random operations) must never be merged
            if n.op in passthrough_ops or get_aten_target(n) in self.banned_ops:
                env[n] = new_graph.node_copy(n, lambda x: env[x])
                continue

            # from here on n.op == 'call_function'; 'call_module' and
            # 'call_method' nodes are not expected in the input graph
            flat_args, args_spec = remap_flat(n.args)
            flat_kwargs, kwargs_spec = remap_flat(n.kwargs)

            # two nodes with equal tokens compute the same value
            token = {
                "target": n.target,
                "args": flat_args,
                "args_spec": args_spec,
                "kwargs": flat_kwargs,
                "kwargs_spec": kwargs_spec
            }

            # the specs are not hashable, so only the flattened values are hashed
            key = (n.target, hash((flat_args, flat_kwargs)))
            key_seen = key in hash_env

            if key_seen and token_map[key] == token:
                # an equivalent node already exists: reuse it, emit nothing
                modified = True
                env[n] = hash_env[key]
                continue

            copied = new_graph.node_copy(n, lambda x: env[x])
            env[n] = copied
            if not key_seen:
                # on a hash collision the first registered node keeps the slot
                hash_env[key] = copied
                token_map[key] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)