def load_arg_impl(arg_or_args):
    """Load ``arg_or_args`` from the environment selected by ``quantized``.

    ``quantized``, ``map_arg``, ``load_x``, ``load_quantized`` and
    ``load_non_quantized`` are free names resolved from the surrounding
    scope; none of them is defined in this block.

    - ``quantized is None``: every node is loaded with ``load_x``.
    - ``quantized`` is a bool: every node is loaded with ``load_quantized``
      when it is true, ``load_non_quantized`` otherwise.
    - ``quantized`` is a tuple/list: ``arg_or_args`` must itself be a
      tuple/list; positions contained in ``quantized`` are loaded with
      ``load_quantized`` and the remaining ones with ``load_non_quantized``.
      The result has the same container type as ``arg_or_args``.

    For any other type of ``quantized`` the function falls through without
    an explicit return.
    """
    if quantized is None:
        return map_arg(arg_or_args, load_x)

    if isinstance(quantized, bool):
        loader = load_quantized if quantized else load_non_quantized
        return map_arg(arg_or_args, loader)

    if isinstance(quantized, (tuple, list)):
        assert isinstance(arg_or_args, (tuple, list)), arg_or_args
        # for now, we only support quantizing positional arguments
        loaded = [
            map_arg(arg, load_quantized if idx in quantized else load_non_quantized)
            for idx, arg in enumerate(arg_or_args)
        ]
        return type(arg_or_args)(loaded)
def _find_quants(self, quant_ctor):
    """Create one quantizer object per input of every matched node.

    Args:
        quant_ctor: callable invoked as ``quant_ctor(self, node)`` the first
            time an argument node is encountered.

    Returns:
        dict mapping the name of each argument node (positional or keyword)
        of a node listed in ``self.matches`` to the object produced by
        ``quant_ctor`` for it.
    """
    quants = {}

    def record(arg_node):
        # note: we have to measure quantization information
        # even for nodes where we might not use it because it is already
        # quantized. This is because each match has the option to
        # say NotImplemented (if for instance, it is an __add__ and the
        # data type is not appropriate)
        if arg_node.name in quants:
            return
        quants[arg_node.name] = quant_ctor(self, arg_node)

    for node in self.graph.nodes:
        if node.name not in self.matches:
            continue
        for inputs in (node.args, node.kwargs):
            map_arg(inputs, record)

    return quants
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Rewrite every node with ``old_op``/``old_target`` to ``new_op``/``new_target``.

    A fresh ``Graph`` is built node by node: matching nodes are re-created
    with the new op code and target (keeping their name and remapped
    args/kwargs), all other nodes are copied unchanged. The new graph is then
    assigned to ``fx_module.graph``; nothing is returned.
    """
    rebuilt = Graph()
    # original node -> its counterpart in the rebuilt graph
    old_to_new: Dict[Node, Node] = {}

    def remap(n):
        return old_to_new[n]

    for node in fx_module.graph.nodes:
        is_target = node.op == old_op and node.target == old_target
        if not is_target:
            old_to_new[node] = rebuilt.node_copy(node, remap)
            continue

        args = map_arg(node.args, remap)
        kwargs = map_arg(node.kwargs, remap)
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)
        old_to_new[node] = rebuilt.create_node(
            new_op, new_target, args, kwargs, node.name)

    rebuilt.output(map_arg(fx_module.graph.result, remap))
    fx_module.graph = rebuilt
def quantize(self):
    """Build ``self.quantized_graph`` from ``self.graph`` and wrap it in a module.

    Two environments, both keyed by node name, are maintained while walking
    ``self.graph.nodes``:

    - ``env``: nodes of the new graph holding non-quantized values.
    - ``quant_env``: results returned by a match object's ``quantize`` (or by
      ``self.quants[name].quantize``) -- presumably the quantized values.

    For each node, ``self.matches`` decides the treatment: unmatched nodes
    are copied, the root node of a match is handed to the match object's
    ``quantize``, and the remaining nodes of a match are skipped (they are
    expected to be handled through their root).

    Returns:
        ``GraphModule(self.root, self.quantized_graph)``.
    """
    self.quantized_graph = Graph()
    env = {}
    quant_env = {}

    def load_arg(n, quantized):
        """Return node ``n`` from the requested environment, converting lazily."""
        if not quantized:
            if n.name not in env and n.name in quant_env:
                # only a quantized version exists: emit a dequantize call
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]
        else:
            if n.name not in quant_env and n.name in env:
                # only a float version exists: quantize it on demand
                quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
            return quant_env[n.name]

    def copy_recursive(node):
        """Copy ``node`` unquantized, emitting any inputs not yet emitted.

        Inner nodes of a match are skipped by the main loop, so when the
        match is rejected they exist in neither environment and have to be
        copied here before the node that consumes them.
        """
        def load_or_emit(n):
            if n.name in env or n.name in quant_env:
                return load_arg(n, quantized=False)
            else:
                return copy_recursive(n)

        r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
        return r

    for node in self.graph.nodes:
        root_node, obj = self.matches.get(node.name, (None, None))
        if root_node is None:
            # not quantized just copy it
            env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
        elif root_node is node:
            r = obj.quantize(
                self, node,
                lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
            if r is NotImplemented:
                # quantizer chose not to quantize the node:
                # take the entire match and just copy it over
                env[node.name] = copy_recursive(node)
            else:
                quant_env[node.name] = r

    self.quantized_graph.output(
        load_arg(self.graph.result, quantized=False))
    return GraphModule(self.root, self.quantized_graph)
def _find_quants(self, graph, matches):
    """Collect the input and output nodes of matches that need quantizing.

    Args:
        graph: object whose ``nodes`` are iterated (an fx ``Graph``).
        matches: mapping of node name to
            ``(root_node, matched, obj, qconfig)``, as produced by
            ``self._find_matches``.

    Returns:
        dict mapping node name to the triple
        ``(DefaultQuant(self, node), qconfig, is_weight)``. A later visit of
        the same node name overwrites the earlier entry.
    """
    quants = {}

    def visit(node, qconfig):
        def visit_arg(arg):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the
            # data type is not appropriate)
            is_weight = False
            if isinstance(node, Node) and node.op == 'call_function' and \
                    node.target in WEIGHT_INDEX_DICT:
                weight_positions = WEIGHT_INDEX_DICT[node.target]
                # arg is a weight when it sits at one of the listed positions
                is_weight = any(
                    arg is node_arg and idx in weight_positions
                    for idx, node_arg in enumerate(node.args))
            if (not self.is_dynamic_quant) or is_weight:
                # overwrite previous quant config
                quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
        return visit_arg

    for node in graph.nodes:
        if node.name not in matches:
            continue
        root_node, matched, obj, qconfig = matches[node.name]
        # don't attach observer/fake_quant for CopyNode
        if isinstance(obj, CopyNode):
            qconfig = None
        if root_node is not node:
            continue

        # matched[-1] is the first op in the sequence and
        # matched[0] is the last op in the sequence
        first_op = matched[-1]
        last_op = matched[0]
        # inputs
        map_arg(first_op.args, visit(first_op, qconfig))
        map_arg(first_op.kwargs, visit(first_op, qconfig))
        # output
        map_arg(last_op, visit(None, qconfig))

    return quants
def load_arg(a):
    """Replace every node inside ``a`` with its ``env`` entry, looked up by name.

    ``env`` and ``map_arg`` are free names taken from the surrounding scope.
    """
    def lookup(node):
        return env[node.name]

    return map_arg(a, lookup)
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
    """Convert an observed ``model`` into a quantized ``GraphModule``.

    The conversion runs in three stages:

    1. Walk ``model.graph`` and build ``self.quantized_graph``: matched
       patterns are converted by their handler object, activation post
       process (observer) modules are replaced by quantize ops, and every
       other node is copied with its inputs dequantized on demand.
    2. Copy ``self.quantized_graph`` into a new graph, dropping the
       activation post process calls that remain.
    3. Delete the activation post process submodules from ``model`` and
       strip its qconfig.

    Args:
        model: module exposing ``graph``, ``named_modules()`` and ``eval()``.
        inplace: when false, ``model`` is deep-copied before being modified.
        debug: not used in this function body.
        is_dynamic_quant: stored on ``self``; when true, weight observers are
            run up front and outputs of converted patterns are recorded as
            non-quantized.

    Returns:
        ``GraphModule`` built from the (possibly copied) model and the graph
        with activation post process calls removed.
    """
    self.restore_state(model)
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    # run weight observers before inserting quant dequant nodes
    # for dynamic quantization
    if self.is_dynamic_quant:
        self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    matches = self._find_matches(model.graph, self.modules, self.patterns)
    quants = self._find_quants(model.graph, matches)
    self.quantized_graph = Graph()
    # node name -> node in quantized_graph holding the float value
    env = {}
    # node name -> node in quantized_graph holding the quantized value
    quant_env = {}

    def load_non_quantized(n):
        """Return the float version of ``n``, dequantizing it if necessary."""
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + str(env) + \
                ' quant_env:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        """Return the quantized version of ``n``, quantizing it if necessary."""
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + n.name + \
                ' in float environment:' + str(env)
            assert n.name in quants, 'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        """Return ``n`` from whichever environment has it, preferring quant_env."""
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """Build a loader for arguments.

        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is None, then we'll load the node as long as it
            exists

        Output: fn which takes arg_or_args, and loads them from the
        corresponding environment depending on the value of quantized.
        """
        assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def is_quantized(node):
        """Tell whether ``node`` (a Node or a list of them) is quantized.

        Raises ``Exception`` for a list mixing quantized and non-quantized
        entries. An empty list counts as quantized. Values that are neither
        a Node nor a list yield ``None``.
        """
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but quant_env will take
            # precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            # materialize the results: a bare map() iterator would be
            # (partly) consumed by all() and leave any() with the leftovers
            quantized = list(map(is_quantized, node))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception("partially quantized inputs in list not handled yet")

    for node in model.graph.nodes:
        root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
        if root_node is node:
            if qconfig is None:
                result = self.quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                result = obj.convert(self, node, load_arg)
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])
                else:
                    quantized = True

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            # non-root member of a match: handled through its root node
            continue

        # handle activation post process calls
        if node.op == 'call_module':
            if is_activation_post_process(self.modules[node.target]):
                observer_module = self.modules[node.target]
                prev_node = node.args[0]
                if observer_module.dtype == torch.float16:
                    # activations are not quantized for
                    # fp16 dynamic quantization
                    # copy the activation_post_process node here
                    # since we may need it when we insert prepack
                    # op for weight of linear, this will be removed
                    # later in a separate pass
                    env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                    continue
                if prev_node.name in quant_env:
                    # if previous node is already quantized, we'll just remove the activation_post_process
                    quant_env[node.name] = quant_env[prev_node.name]
                    continue
                # replace activation post process with quantization ops
                root_module = self.modules['']
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)
                continue
        # dequantize inputs for the node that are not quantized
        env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
    self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

    # remove activation post process
    act_post_process_removed_graph = Graph()
    # fresh mapping for the second pass: node name -> node in the new graph
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'call_module' and \
           is_activation_post_process(self.modules[node.target]):
            # remove activation post process
            env[node.name] = env[node.args[0].name]
        else:
            env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg)
    act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_arg))

    module_dict = dict(model.named_modules())
    # collect names first so the module tree is not mutated while iterating
    to_be_removed = []
    for name, module in model.named_modules():
        if is_activation_post_process(module) and not is_submodule_of_fake_quant(name, module, module_dict):
            to_be_removed.append(name)
    for n in to_be_removed:
        delattr(model, n)
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model