class Quantizer:
    """Pattern-based quantization driver for an FX-traced module.

    Workflow implemented by this class:

    1. ``__init__`` matches ``patterns`` against the graph (``self.matches``)
       and creates one ``quant_ctor`` object for every input of a matched node
       (``self.quants``).
    2. ``observe(args)`` interprets the graph on sample inputs so that the
       match objects and quant objects can record statistics.
    3. ``quantize()`` builds a new graph in which matched subgraphs are
       replaced by whatever their match object emits, and returns it wrapped
       in a ``GraphModule``.
    """

    def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
        """
        Args:
            mod: the traced module to quantize; its ``graph``, ``state_dict()``
                and ``named_modules()`` are read here.
            patterns: mapping from pattern to a factory that is called as
                ``factory(quantizer, root_node)`` for each match; the object it
                returns handles ``observe``/``quantize`` for that match.
            quant_ctor: factory called as ``quant_ctor(quantizer, node)`` for
                every input of a matched node.
        """
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize a quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        """Run the graph on ``args`` and let match/quant objects observe it.

        ``args`` is consumed in placeholder order. After each node runs, the
        match object whose root is that node (if any) and the quant object for
        that node (if any) get ``observe(node, env)`` called, where ``env``
        maps node names to the values computed so far.

        Returns:
            The value of ``self.graph.result`` under the computed environment.
        """
        # most of this function is just an interpreter for the graph
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to just be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people copy/paste
        # then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            # NOTE(review): an op not listed below would reuse the previous
            # iteration's `result`; presumably the graph only contains these ops.
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # named method_args so the `args` parameter is not shadowed
                self_obj, *method_args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        """Build the quantized graph and return it as a ``GraphModule``.

        Unmatched nodes are copied as-is (with dequantized inputs). For each
        match, the match object's ``quantize`` is called at the root node with
        a loader that yields quantized inputs. If it returns ``NotImplemented``,
        the root node is copied unquantized, together with any not-yet-emitted
        nodes it depends on (the interior nodes of the same match, which the
        main loop skipped).
        """
        self.quantized_graph = Graph()

        env = {}        # node name -> float (dequantized) node in the new graph
        quant_env = {}  # node name -> quantized node in the new graph

        def load_arg(n, quantized):
            # Return the new-graph node for `n` in the requested representation,
            # inserting a dequantize/quantize conversion on first demand.
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy `node` unquantized; any input that has not been emitted yet
            # (an interior node of the declined match) is copied first.
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node,
                    lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer chose not to quantize the node; take the
                    # entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # otherwise the node is inside a match whose root comes later; the
            # root's quantize (or copy_recursive) is responsible for it

        self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        """Return a map: node name -> (root_node, match_object).

        Nodes are visited in reverse graph order; a node already claimed by a
        match is not used as a new root. ``apply_match`` marks the root and
        every node covered by a (nested tuple) sub-pattern with the same entry.
        """
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                # NOTE(review): there is no `break`, so if several patterns
                # match, later ones overwrite entries written by earlier ones;
                # confirm this is intended.
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        """Create a ``quant_ctor(self, n)`` for every arg/kwarg node ``n`` of a matched node.

        Returns:
            dict mapping node name -> quant object.
        """
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
def call(self, graph_module: GraphModule) -> PassResult:
    """
    Return a new copy of torch.fx.GraphModule with CSE applied to the input graph

    Two non-banned ``call_function`` nodes are merged when they have the same
    target, the same (new-graph-substituted) args/kwargs structure, and
    argument leaves that are equal *and of the same Python type*. The type
    check matters because Python treats ``1 == 1.0 == True`` and gives them
    the same hash, yet e.g. ``x * 1`` and ``x * 1.0`` can produce tensors of
    different dtypes and must not be merged.

    Placeholder, output and get_attr nodes, and any node whose target (mapped
    to its overload packet when it has one) is in ``self.banned_ops``, are
    copied unchanged.

    Example usage:

    from torch.fx.experimental.proxy_tensor import make_fx
    def f(a):
        b = a * a
        c = a * a
        return b+c

    p = CSEPass()
    traced_graph = make_fx(f)(torch.tensor(1))
    print(traced_graph)
    result = p(traced_graph)
    print(result.graph_module)

    Args:
        graph_module: module whose graph is deduplicated; it is not mutated.

    Returns:
        PassResult holding a new GraphModule built from the deduplicated graph,
        and ``modified=True`` iff at least one node was eliminated.
    """
    modified = False
    new_graph = Graph()
    env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
    hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
    token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token

    def get_aten_target(node):
        # OpOverloads (e.g. aten.rand.default) are compared by their packet.
        if hasattr(node.target, 'overloadpacket'):
            return node.target.overloadpacket
        return node.target

    def substitute(arg_list):
        # Flatten (possibly nested) args/kwargs and replace old-graph nodes by
        # their new-graph counterparts; the spec is kept so that the nested
        # structure can be compared as part of the token.
        leaves, spec = tree_flatten(arg_list)
        leaves = [env[v] if isinstance(v, Node) and v in env else v for v in leaves]
        return tuple(leaves), spec

    def with_types(leaves):
        # Pair each leaf with its type so that 1, 1.0 and True stay distinct
        # both when hashing and when comparing tokens.
        return tuple((v, type(v)) for v in leaves)

    for n in graph_module.graph.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change
        # do not CSE away random operations
        if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in self.banned_ops:
            env[n] = new_graph.node_copy(n, lambda x: env[x])
            continue

        # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)
        typed_args = with_types(args)
        typed_kwargs = with_types(kwargs)

        # each token corresponds to a unique node
        # nodes with the same token can be substituted
        token = {
            "target": n.target,
            "args": typed_args,
            "args_spec": args_spec,
            "kwargs": typed_kwargs,
            "kwargs_spec": kwargs_spec,
        }

        # hash substituted args to a number, do not hash specs because specs are not hashable
        hash_val = (n.target, hash((typed_args, typed_kwargs)))

        # check if a node has a substitute and can be eliminated
        already_seen = hash_val in hash_env
        if already_seen and token_map[hash_val] == token:
            modified = True  # substitution happens and the graph is modified
            env[n] = hash_env[hash_val]
            continue

        new_node = new_graph.node_copy(n, lambda x: env[x])
        env[n] = new_node
        if not already_seen:
            # on a hash collision with a different token, keep the first entry
            hash_env[hash_val] = new_node
            token_map[hash_val] = token

    csed_gm = GraphModule(graph_module, new_graph)
    return PassResult(csed_gm, modified)