def _find_quants(self, quant_ctor):
    quants = {}

    def visit_arg(n):
        # note: we have to measure quantization information
        # even for nodes where we might not use it because it is already
        # quantized. This is because each match has the option to say
        # NotImplemented (if for instance, it is an __add__ and the data
        # type is not appropriate)
        if n.name not in quants:
            quants[n.name] = quant_ctor(self, n)

    for node in self.graph.nodes:
        if node.name in self.matches:
            map_arg(node.args, visit_arg)
            map_arg(node.kwargs, visit_arg)
    return quants
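# A minimal, hypothetical illustration of the traversal idiom used by
# `_find_quants` above: `torch.fx.node.map_arg` applies a callback to every
# Node inside a (possibly nested) args/kwargs structure, so calling it purely
# for its side effect visits each input Node once. The toy function below is
# an assumption for demonstration, not part of the quantizer.
import torch
from torch.fx import symbolic_trace
from torch.fx.node import map_arg

def toy(x, y):
    return torch.cat([x + 1, y])

gm = symbolic_trace(toy)
input_names = []
for node in gm.graph.nodes:
    # visit_arg-style callback: records every Node feeding this node,
    # including Nodes nested inside lists such as the one passed to torch.cat
    map_arg(node.args, lambda n: input_names.append(n.name))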
def _find_quants(self, graph, matches):
    quants = {}

    def visit(node, qconfig):
        def visit_arg(arg):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to say
            # NotImplemented (if for instance, it is an __add__ and the data
            # type is not appropriate)
            is_weight = False
            if isinstance(node, Node) and node.op == 'call_function' and \
                    node.target in WEIGHT_INDEX_DICT:
                for i, node_arg in enumerate(node.args):
                    if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                        is_weight = True
            if self.quant_type != QuantType.DYNAMIC or is_weight:
                # overwrite previous quant config
                quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
        return visit_arg

    for node in graph.nodes:
        if node.name in matches:
            root_node, matched, obj, qconfig = matches[node.name]
            # don't attach observer/fake_quant for CopyNode
            if isinstance(obj, CopyNode):
                qconfig = None
            if root_node is node:
                # matched[-1] is the first op in the sequence and
                # matched[0] is the last op in the sequence
                # inputs
                map_arg(matched[-1].args, visit(matched[-1], qconfig))
                map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                # output
                map_arg(matched[0], visit(None, qconfig))
    return quants
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Modifies all nodes in fx_module.graph.nodes which match the specified
    op code and target, and updates them to match the new op code and target"""
    new_graph = Graph()
    val_map: Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(new_op, new_target, args, kwargs, node.name)
        else:
            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
    fx_module.graph = new_graph
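# A sketch of how `replace_target_nodes_with` might be used; the module and
# targets here are illustrative assumptions. `x + y` traces to a call_function
# node with target `operator.add`, which we rewrite to `operator.mul`. The
# function swaps out `fx_module.graph` but does not regenerate the forward
# code itself, so we call `recompile()` afterwards to be safe.
import operator
import torch
from torch.fx import symbolic_trace

class AddModule(torch.nn.Module):
    def forward(self, x, y):
        return x + y

gm = symbolic_trace(AddModule())
replace_target_nodes_with(gm, "call_function", operator.add, "call_function", operator.mul)
gm.recompile()
assert torch.equal(gm(torch.tensor(2.0), torch.tensor(3.0)), torch.tensor(6.0))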
def quantize(self):
    self.quantized_graph = Graph()
    self.delegate = DelegateBase(self.quantized_graph)

    env = {}
    quant_env = {}

    def load_arg(n, quantized):
        if not quantized:
            if n.name not in env and n.name in quant_env:
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]
        else:
            if n.name not in quant_env and n.name in env:
                quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
            return quant_env[n.name]

    def copy_recursive(node):
        def load_or_emit(n):
            if n.name in env or n.name in quant_env:
                return load_arg(n, quantized=False)
            else:
                return copy_recursive(n)
        r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
        return r

    for node in self.graph.nodes:
        root_node, obj = self.matches.get(node.name, (None, None))
        if root_node is None:
            # not quantized, just copy it
            env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
        elif root_node is node:
            r = obj.quantize(
                self, node,
                lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
            if r is NotImplemented:
                # the quantizer chose not to quantize the node; take the
                # entire match and just copy it over
                env[node.name] = copy_recursive(node)
            else:
                quant_env[node.name] = r

    self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
    return GraphModule(self.root, self.quantized_graph)
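# The `load_arg` closure above splices dequantize ops into the graph being
# built by wrapping an existing Node in a `torch.fx.Proxy`: method calls on a
# Proxy append new nodes to the Proxy's graph. A minimal sketch of that idiom
# (the placeholder name is an arbitrary assumption):
import torch
from torch.fx import Graph, Proxy

g = Graph()
x = g.placeholder("x")
# Appends a call_method node with target "dequantize" to `g`; `.node`
# recovers the underlying fx Node from the returned Proxy.
dequantized = Proxy(x).dequantize().node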
def _find_quants(self, graph, matches):
    """
    Takes the nodes in the input graph and pending matches, and finds and
    returns the input and output nodes which need to be quantized.

    Inputs:
      - graph: an fx.Graph object
      - matches: output of self._find_matches function

    Outputs a map of
      node_name -> (QuantizeHandler instance (always DefaultQuant), qconfig, is_weight)
    """
    quants = {}

    def visit(node, qconfig):
        def visit_arg(arg):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to say
            # NotImplemented (if for instance, it is an __add__ and the data
            # type is not appropriate)
            is_weight = False
            if isinstance(node, Node) and node.op == 'call_function' and \
                    node.target in WEIGHT_INDEX_DICT:
                for i, node_arg in enumerate(node.args):
                    if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                        is_weight = True
            if qconfig is not None and \
                    (activation_is_statically_quantized(qconfig) or is_weight):
                # overwrite previous quant config
                quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
        return visit_arg

    for node in graph.nodes:
        if node.name in matches:
            root_node, matched, obj, qconfig = matches[node.name]
            # don't attach observer/fake_quant for CopyNode
            if isinstance(obj, CopyNode):
                qconfig = None
            if root_node is node:
                # matched[-1] is the first op in the sequence and
                # matched[0] is the last op in the sequence
                # inputs
                map_arg(matched[-1].args, visit(matched[-1], qconfig))
                map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                # output
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # we don't insert an observer for the output of a
                    # standalone module
                    continue
                map_arg(matched[0], visit(None, qconfig))
    return quants
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
    self.restore_state(model)
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    # run weight observers before inserting quant dequant nodes
    # for dynamic quantization
    if self.is_dynamic_quant:
        self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    matches = self._find_matches(model.graph, self.modules, self.patterns)
    quants = self._find_quants(model.graph, matches)
    self.quantized_graph = Graph()
    env = {}
    quant_env = {}

    def load_non_quantized(n):
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + str(env) + \
                ' quant_env:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + n.name + \
                ' in float environment:' + str(env)
            assert n.name in quants, 'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is None, then we'll load the node as long as it exists

        Output: fn which takes arg_or_args, and loads them from the
        corresponding environment depending on the value of quantized.
        """
        assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(arg_or_args,
                               load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def is_quantized(node):
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, \
                'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but
            # quant_env will take precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            quantized = list(map(is_quantized, node))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception("partially quantized inputs in list not handled yet")

    for node in model.graph.nodes:
        root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
        if root_node is node:
            if qconfig is None:
                result = self.quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                result = obj.convert(self, node, load_arg)
                # Need to get correct quantized/non-quantized state for the
                # output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])
                else:
                    quantized = True

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            continue

        # handle activation post process calls
        if node.op == 'call_module':
            if node.target.split('.')[-1].startswith('activation_post_process_'):
                observer_module = self.modules[node.target]
                prev_node = node.args[0]
                if observer_module.dtype == torch.float16:
                    # activations are not quantized for fp16 dynamic
                    # quantization; copy the activation_post_process node here
                    # since we may need it when we insert the prepack op for
                    # the weight of linear, this will be removed later in a
                    # separate pass
                    env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                    continue
                if prev_node.name in quant_env:
                    # if the previous node is already quantized, we'll just
                    # remove the activation_post_process
                    quant_env[node.name] = quant_env[prev_node.name]
                    continue
                # replace activation post process with quantization ops
                root_module = self.modules['']
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)
                continue

        # dequantize inputs for the nodes that are not quantized
        env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)

    self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

    # remove activation post process
    act_post_process_removed_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'call_module' and \
                node.target.split('.')[-1].startswith('activation_post_process_'):
            # remove activation post process
            env[node.name] = env[node.args[0].name]
        else:
            env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg)
    act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_arg))

    to_be_removed = []
    for name, _ in model.named_modules():
        if name.split('.')[-1].startswith('activation_post_process_'):
            to_be_removed.append(name)
    for n in to_be_removed:
        delattr(model, n)
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model
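# `_convert` is internal to the Quantizer; the public entry points wrapping
# the prepare/calibrate/convert flow are `prepare_fx` and `convert_fx`. A
# hedged sketch of end-to-end usage (the module shape and the "fbgemm"
# backend choice are assumptions for illustration):
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(float_model, qconfig_dict)  # inserts activation_post_process observers
prepared(torch.randn(8, 4))                       # calibration pass to record statistics
quantized = convert_fx(prepared)                  # lowers the observed graph to quantized ops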
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, if we have tags = ["a", "b", "c"], the function will
    create the initial submodules in the order of "a_0", "b_1", "c_2".

    To set a tag:
    gm.graph.nodes[idx].tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed
    in their own submodule. For placeholder, output and get_attr nodes, the
    tag is ignored. placeholder and output nodes are created when needed,
    while get_attr nodes get copied to the submodules where they are used.

    Given the following module def:

    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower()
    results in the following split:

    ro_0:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main_1:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    main_0:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Stores nodes in x to a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in the main graph.
    used_in_main: NodeSet = set()

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other
    # components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in the original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to the main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly to the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes with op of call_module,
        # call_function or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have a higher order than its
        # upstream components.
        assert comp.order >= mx

        # Map an input of `node` to a node in the component's graph.
        def remap_func(x):
            # If the input is a get_attr node, copy it to the current
            # component's graph and return the copy.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type)
                return comp.getattr_maps[x]

            # If the input is not a placeholder, it should have been put into
            # a component already. If it's in the current component then we
            # return the corresponding node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If the input is a placeholder or it's in another component, we
            # want to make it a placeholder in the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                comp.input_placeholders.append(
                    comp.graph.placeholder(x.name, type_expr=x.type))
                used_in_main.add(x)

            return comp.input_placeholders[next(
                i for i, y in enumerate(comp.orig_inputs) if x is y)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        used_in_main.add(x)

    # If a node is used in the main graph then we mark it as an output in the
    # component it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a GraphModule for each component.
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        # Take care of the args of the FX output node. If there's a single
        # output then the output node args is like (output_single,), else
        # if there are multiple outputs then the output node args is like
        # ((output_0, output_1, ...),).
        comp.graph.output(outs[0] if len(outs) == 1 else outs)

        # Loop through all module calls (call_module) and param fetches
        # (get_attr) in this component, creating HolderModules as necessary to
        # match the path. e.g. if in the original module there's a get_attr
        # node that fetches "conv.weight", we create a HolderModule as root ->
        # add a HolderModule named "conv" -> make "weight" an attribute of the
        # "conv" HolderModule pointing to conv.weight in the original module.
        root = HolderModule({})
        for n in comp.graph.nodes:
            if n.op not in ("call_module", "get_attr"):
                continue

            target = n.target
            assert isinstance(target, str)
            target_name_parts = target.split(".")
            curr = root
            orig_gm = gm

            for name in target_name_parts[:-1]:
                if not hasattr(curr, name):
                    curr.add_module(name, HolderModule({}))

                curr = getattr(curr, name)
                orig_gm = getattr(orig_gm, name)

            leaf_node_name = target_name_parts[-1]
            leaf_node = getattr(orig_gm, leaf_node_name)

            # Relies on custom __setattr__ magic.
            setattr(curr, leaf_node_name, leaf_node)

        comp.gm = torch.fx.GraphModule(root, comp.graph)

        # Create a call_module node in the main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))

    main_root = HolderModule({comp.name: comp.gm for comp in all_components})

    return torch.fx.GraphModule(main_root, main_g)
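# A small, self-contained sketch of driving `split_by_tags`, assuming it is
# importable as in torch.fx.passes.split_utils. Tag names and layer sizes are
# arbitrary assumptions; every call_module/call_function/call_method node must
# carry a tag before splitting.
import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

class TwoLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))

gm = torch.fx.symbolic_trace(TwoLinear())
for node in gm.graph.nodes:
    if node.op == "call_module":
        # first stage goes to component "a", second to component "b"
        node.tag = "a" if node.target == "linear1" else "b"

split = split_by_tags(gm, ["a", "b"])
x = torch.randn(2, 4)
# the split module should compute the same result as the original
assert torch.allclose(gm(x), split(x))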