def fuse(self, model, inplace=False):
    """Fuse the patterns returned by ``get_fusion_patterns()`` in ``model``.

    Args:
        model: object exposing ``graph`` and ``named_modules()``; it is used
            as the root of the returned GraphModule.
        inplace: when False, a ``copy.deepcopy`` of ``model`` is fused so the
            caller's object is left untouched.

    Returns:
        A new GraphModule built from the (possibly copied) model and
        ``self.fused_graph``.
    """
    if not inplace:
        model = copy.deepcopy(model)
    root = model
    source_graph = root.graph
    self.modules = dict(root.named_modules())

    # node name -> (root node of the match, handler object exposing .fuse)
    matches = self._find_matches(root, source_graph, get_fusion_patterns())

    self.fused_graph = Graph()
    env = {}

    def load_arg(a):
        # Remap every Node inside `a` to its counterpart in the fused graph.
        return map_arg(a, lambda n: env[n.name])

    for node in source_graph.nodes:
        match_root, handler = matches.get(node.name, (None, None))
        if match_root is None:
            # Not part of any fusion pattern: copy it verbatim.
            env[node.name] = self.fused_graph.node_copy(node, load_arg)
        elif match_root is node:
            # Root of a matched pattern: the handler emits the fused op.
            env[node.name] = handler.fuse(self, load_arg)
        # Any other node is a non-root member of a match; it is absorbed by
        # the fused op and deliberately not copied.

    self.fused_graph.output(load_arg(source_graph.result))
    return GraphModule(root, self.fused_graph)
def transform(traced):
    """Return a GraphModule whose output is ``neg`` of ``traced``'s last node.

    The graph of ``traced`` is copied into a fresh graph, a ``neg`` method
    call is appended on that graph's last node, and the call becomes the
    new output.
    """
    graph = torch._fx.Graph()
    graph.graph_copy(traced.graph)
    last_node = graph.nodes[-1]
    negated = graph.create_node(
        op='call_method',
        target='neg',
        args=(last_node,),
        kwargs={},
    )
    graph.output(negated)
    return GraphModule(traced, graph)
def test_graph_edit_with_proxy(self):
    """Extend a copied graph through a Proxy and run the rebuilt module.

    The traced graph of ``a + b`` is copied, its last node is wrapped in a
    Proxy, and ``t + t`` is emitted as the new output, so the module
    computes ``2 * (a + b)``.
    """
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    module = M()
    traced_graph = symbolic_trace(module).graph
    edited = torch._fx.Graph()
    edited.graph_copy(traced_graph)
    # Proxies let us keep generating graph code with plain Python operators,
    # which is enough for anything that does not need to touch submodules.
    tail = Proxy(edited.nodes[-1])
    edited.output((tail + tail).node)
    gm = GraphModule(module, edited)
    gm.graph.lint(gm)
    self.assertEqual(gm(3, 4), 14)
def quantize(self):
    """Rewrite ``self.graph`` into ``self.quantized_graph`` and wrap it.

    Reads:
        self.graph: source graph.
        self.matches: node name -> (root node of the match, handler). The
            handler's ``quantize(self, node, load_arg)`` may return
            ``NotImplemented`` to decline quantizing the match.
        self.quants: node name -> object whose ``quantize(node)`` turns a
            float value into a quantized one.
        self.root: module used as the root of the returned GraphModule.

    Two environments map source-node names to nodes of the new graph:
    ``env`` holds float values and ``quant_env`` holds quantized ones.
    Conversions between them are emitted lazily, on first demand.

    Returns:
        GraphModule(self.root, self.quantized_graph).
    """
    self.quantized_graph = Graph()

    env = {}
    quant_env = {}

    def load_arg(n, quantized):
        # Return the new-graph value of `n` in the requested representation,
        # converting from the other environment when only that one has it.
        if not quantized:
            if n.name not in env and n.name in quant_env:
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]
        else:
            if n.name not in quant_env and n.name in env:
                quant_env[n.name] = self.quants[n.name].quantize(
                    env[n.name])
            return quant_env[n.name]

    def copy_recursive(node):
        # Copy `node` into the new graph as a float op. Its inputs may be
        # non-root members of the same match, which the main loop skips, so
        # any input not yet emitted is copied recursively first.
        def load_or_emit(n):
            if n.name in env or n.name in quant_env:
                return load_arg(n, quantized=False)
            else:
                return copy_recursive(n)
        r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
        return r

    for node in self.graph.nodes:
        root_node, obj = self.matches.get(node.name, (None, None))
        if root_node is None:
            # not quantized just copy it
            env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))

        elif root_node is node:
            r = obj.quantize(
                self, node,
                lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
            if r is NotImplemented:
                # The quantizer chose not to quantize this match; copy the
                # entire match over as float ops instead.
                env[node.name] = copy_recursive(node)
            else:
                quant_env[node.name] = r
        # Non-root members of a match are not copied here: the handler (or
        # copy_recursive) is responsible for them.

    self.quantized_graph.output(
        load_arg(self.graph.result, quantized=False))
    return GraphModule(self.root, self.quantized_graph)
def test_graph_fns(self):
    """Build a graph with each Graph helper and compare against eager code."""
    graph = Graph()
    inp = graph.placeholder('a')
    projected = graph.call_module('linear', (inp,))
    bias_attr = graph.get_attr('bias')
    summed = graph.call_method('add', (projected, bias_attr))
    result = graph.call_function(torch.sin, (summed,))
    graph.output(result)

    root = torch.nn.Module()
    root.linear = torch.nn.Linear(3, 4)
    root.bias = torch.rand(4)
    gm = GraphModule(root, graph)
    gm.graph.lint(gm)

    x = torch.rand(3)
    actual = gm(x)
    expected = torch.sin(root.linear(x) + root.bias)
    self.assertEqual(actual, expected)
def test_graph_unique_names(self):
    """Nodes added through a Proxy must not reuse an existing node name."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    module = M()
    traced_graph = symbolic_trace(module).graph
    edited = torch._fx.Graph()
    edited.graph_copy(traced_graph)
    # Append `t + t` through a Proxy; the graph must pick a name for the new
    # node that does not clash with the copied ones.
    tail = Proxy(edited.nodes[-1])
    edited.output((tail + tail).node)
    gm = GraphModule(module, edited)

    seen_names: Set[str] = set()
    for name in (n.name for n in gm.graph.nodes):
        assert name not in seen_names
        seen_names.add(name)
def _fold_weight(self, quantized):
    """Pre-compute weight-prepack subgraphs and replace them with attributes.

    For every ``call_function`` node whose target is in
    ``WEIGHT_PREPACK_OPS`` and for which ``collect_producer_nodes`` returns
    a node list, that producer subgraph is executed once. The result is
    stored on ``quantized`` under a fresh ``_fx_pass_packed_weight_*``
    attribute, and the prepack node is replaced by a ``get_attr`` of it.
    The remaining producer nodes are dropped from the graph.

    Returns:
        A new GraphModule rooted at ``quantized`` with the folded graph.
    """
    # prepack node name -> computed packed weight
    packed_weights = {}
    # name of any node folded away -> prepack node it belongs to
    folded_nodes = {}

    for node in quantized.graph.nodes:
        if node.op != 'call_function' or node.target not in WEIGHT_PREPACK_OPS:
            continue
        producers = collect_producer_nodes(node)
        if producers is None:
            continue
        for producer in producers:
            folded_nodes[producer.name] = node
        prepacker = graph_module_from_producer_nodes(quantized, producers)
        packed_weights[node.name] = prepacker()

    folded_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    new_packed_weight_name = get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
    root = quantized
    source_graph = quantized.graph
    for node in source_graph.nodes:
        owner = folded_nodes.get(node.name, None)
        if owner is None:
            # Not part of any prepack subgraph: copy it unchanged.
            env[node.name] = folded_graph.node_copy(node, load_arg)
        elif owner is node:
            # The prepack op itself: attach its packed weight to the root and
            # read it back through a get_attr node.
            weight = packed_weights[node.name]
            attr_name = new_packed_weight_name(root)
            setattr(root, attr_name, weight)
            env[node.name] = folded_graph.create_node(
                'get_attr', attr_name, (), {})
        # Otherwise: an intermediate producer folded into a prepack op; drop it.

    folded_graph.output(load_arg(source_graph.result))
    return GraphModule(root, folded_graph)
def graph_module_from_producer_nodes(root, producer_nodes):
    r'''
    Construct a graph module from the producer nodes extracted by the
    `collect_producer_nodes` function.

    Args:
      root: the root module for the original graph
      producer_nodes: a non-empty list of nodes used to construct the graph,
        ordered as collected by tracing back from the consumer node towards
        its get_attr inputs. The list is not modified.

    Return:
      A graph module constructed from the producer nodes; its output is the
      copy of ``producer_nodes[0]``.
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # The nodes were traced back from the consumer to get_attr, so iterate in
    # reverse to copy producers before their users. Build a reversed copy
    # rather than calling list.reverse(), so the caller's list is untouched.
    ordered_nodes = list(reversed(producer_nodes))
    graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])

    for producer_node in ordered_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(ordered_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
def lower_to_elementwise_interpreter(
        orig_mod: torch.nn.Module) -> torch.nn.Module:
    """Lower ``orig_mod`` onto the TorchScript-testing elementwise interpreter.

    Stage 1 symbolically traces the module. Stage 2 turns each graph node
    into an ``(instruction_name, input_names, output_name)`` triple and loads
    the triples, any non-Node constant arguments, the input names and the
    output name into an ``_ElementwiseInterpreter``. Stage 3 builds a
    GraphModule whose generated code passes its placeholders to that
    interpreter.

    Only ``placeholder`` nodes and ``call_function`` nodes targeting
    ``operator.add`` / ``operator.mul`` without kwargs are supported.

    Raises:
        AssertionError: a node has kwargs or an unsupported call target.
        RuntimeError: a node has an unsupported opcode.
    """
    # ===== Stage 1: Symbolic trace the module =====
    mod = symbolic_trace(orig_mod)

    # ===== Stage 2: Lower GraphModule representation to the C++
    #       interpreter's instruction format ======
    instructions = []
    constant_idx = 0
    constants = {}
    fn_input_names = []

    target_to_name = {
        operator.add: "add",
        operator.mul: "mul",
    }

    # For each instruction, create a triple
    # (instruction_name : str, inputs : List[str], output : str)
    # to feed into the C++ interpreter
    for n in mod.graph.nodes:
        target, args, out_name = n.target, n.args, n.name
        assert len(n.kwargs) == 0, "kwargs currently not supported"

        if n.op == 'placeholder':
            # Placeholders specify function argument names. Save these
            # for later when we generate the wrapper GraphModule
            fn_input_names.append(target)
        elif n.op == 'call_function':
            # `target` is a callable, so it must be formatted rather than
            # concatenated: str + function raises TypeError and would hide
            # the intended AssertionError.
            assert target in target_to_name, f"Unsupported call target {target}"
            arg_names = []
            for arg in args:
                if not isinstance(arg, Node):
                    # Pull out constants. These constants will later be
                    # fed to the interpreter C++ object via add_constant()
                    arg_name = f'constant_{constant_idx}'
                    constants[arg_name] = torch.Tensor(
                        [arg] if isinstance(arg, numbers.Number) else arg)
                    arg_names.append(arg_name)
                    constant_idx += 1
                else:
                    arg_names.append(arg.name)
            instructions.append((target_to_name[target], arg_names, out_name))
        else:
            raise RuntimeError('Unsupported opcode ' + n.op)

    interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
    # Load constants
    for k, v in constants.items():
        interpreter.add_constant(k, v)
    # Specify names for positional input arguments
    interpreter.set_input_names(fn_input_names)
    # Load instructions
    interpreter.set_instructions(instructions)
    # Specify name for single output
    interpreter.set_output_name(mod.graph.result.name)

    # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
    class WrapperModule(torch.nn.Module):
        def __init__(self, interpreter):
            super().__init__()
            self.interpreter = interpreter

    wrapper = WrapperModule(interpreter)

    # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
    # 3) Returns the specified return value

    # FIXME: The following code could be greatly simplified by symbolic_trace'ing
    # the wrapper with a Tracer that considers the Wrapper instance a root
    # module, however, I can't get `__call__` exposed on TorchBind classes
    # without it messing up Python `hasattr` for some reason. More digging
    # into CPython's implementation of hasattr is probably in order...

    graph = torch._fx.Graph()
    # Add placeholders for fn inputs
    placeholder_nodes = []
    for name in fn_input_names:
        placeholder_nodes.append(graph.create_node('placeholder', name))

    # Get the interpreter object
    interpreter_node = graph.create_node('get_attr', 'interpreter')

    # Add a node to call the interpreter instance
    output_node = graph.create_node(
        op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

    # Register output
    graph.output(output_node)

    graph.lint(wrapper)

    # Return final GraphModule!!!
    return GraphModule(wrapper, graph)
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
    """Convert an observed model into a quantized GraphModule.

    Args:
        model: the observed model. ``self.restore_state(model)`` is called
            on it before anything else.
        inplace: when False, a ``copy.deepcopy`` of ``model`` is converted.
        debug: accepted for API compatibility; not read by this method.
        is_dynamic_quant: when True, weight observers are run first and the
            outputs of matched patterns are kept as float values.

    The model is put in eval mode and moved to CPU. Matched patterns are
    converted by their handlers. Activation-post-process calls are replaced
    by quantize ops, or dropped when their input is already quantized. All
    remaining activation-post-process nodes are then removed in a second
    pass, and the observer submodules are deleted, except those for which
    ``is_submodule_of_fake_quant`` is true.

    Returns:
        A new GraphModule built from the model and the cleaned graph.
    """
    self.restore_state(model)
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    # run weight observers before inserting quant dequant nodes
    # for dynamic quantization
    if self.is_dynamic_quant:
        self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    matches = self._find_matches(model.graph, self.modules, self.patterns)

    quants = self._find_quants(model.graph, matches)

    self.quantized_graph = Graph()
    # env holds float (non-quantized) values, quant_env holds quantized ones;
    # both map source-node names to nodes of self.quantized_graph.
    env = {}
    quant_env = {}

    def load_non_quantized(n):
        # Float value of `n`, dequantizing on demand if only quantized exists.
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + str(env) + \
                ' quant_env:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        # Quantized value of `n`, using its quant object on demand.
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + n.name + \
                ' in float environment:' + str(env)
            assert n.name in quants, 'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        # Whichever representation exists, preferring the quantized one.
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is None, then we'll load the node as long as it
            exists
        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def is_quantized(node):
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but quant_env will take
            # precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            # Materialize the flags: a bare `map` iterator would be partly
            # consumed by all() before any() saw it, so e.g. [True, False]
            # would be reported as non-quantized instead of raising.
            flags = [is_quantized(n) for n in node]
            if all(flags):
                return True
            elif not any(flags):
                return False
            else:
                raise Exception("partially quantized inputs in list not handled yet")

    for node in model.graph.nodes:
        root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
        if root_node is node:
            if qconfig is None:
                result = self.quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                result = obj.convert(self, node, load_arg)
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])
                else:
                    quantized = True

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            # Non-root member of a match: handled by the root's converter.
            continue

        # handle activation post process calls
        if node.op == 'call_module':
            if is_activation_post_process(self.modules[node.target]):
                observer_module = self.modules[node.target]
                prev_node = node.args[0]
                if observer_module.dtype == torch.float16:
                    # activations are not quantized for
                    # fp16 dynamic quantization
                    # copy the activation_post_process node here
                    # since we may need it when we insert prepack
                    # op for weight of linear, this will be removed
                    # later in a separate pass
                    env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                    continue
                if prev_node.name in quant_env:
                    # if previous node is already quantized, we'll just remove the activation_post_process
                    quant_env[node.name] = quant_env[prev_node.name]
                    continue
                # replace activation post process with quantization ops
                root_module = self.modules['']
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)
                continue
        # dequantize inputs for the node that are not quantized
        env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
    self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

    # remove activation post process
    act_post_process_removed_graph = Graph()
    cleaned_env = {}

    def load_cleaned_arg(a):
        return map_arg(a, lambda node: cleaned_env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'call_module' and \
           is_activation_post_process(self.modules[node.target]):
            # remove activation post process: forward its input instead
            cleaned_env[node.name] = cleaned_env[node.args[0].name]
        else:
            cleaned_env[node.name] = act_post_process_removed_graph.node_copy(node, load_cleaned_arg)
    act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_cleaned_arg))

    module_dict = dict(model.named_modules())
    to_be_removed = []
    for name, module in model.named_modules():
        if is_activation_post_process(module) and not is_submodule_of_fake_quant(name, module, module_dict):
            to_be_removed.append(name)
    for n in to_be_removed:
        delattr(model, n)
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
    """Insert observer modules into ``model``'s graph and return it.

    Args:
        model: the traced model; a ``copy.deepcopy`` is used unless
            ``inplace`` is true.
        qconfig_dict: flattened and propagated onto submodules via
            ``propagate_qconfig_``. It is also passed to
            ``convert_dict_to_ordered_dict``, whose return value is ignored
            (presumably it converts the dict in place — confirm).
        inplace: when False, work on a deep copy of ``model``.
        is_dynamic_quant: selects ``get_dynamic_quant_patterns()``, and skips
            output observers for matched pattern roots.

    When ``model.training`` is true, ``self._qat_swap_modules(model)`` is
    called. Observers are attached to the model with ``setattr`` under
    names built from ``<node name>_activation_post_process_``. They are
    recorded in ``self.activation_post_process_map`` keyed by node name.

    Returns:
        GraphModule(model, observed_graph), after ``self.save_state`` is
        called on it.
    """
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    if self.is_dynamic_quant:
        self.patterns = get_dynamic_quant_patterns()
    else:
        self.patterns = get_quant_patterns()

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        self._qat_swap_modules(model)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    matches = self._find_matches(model.graph, self.modules, self.patterns)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize an DefaultQuant object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()

    # source-node name -> node in observed_graph (the observer call when one
    # was inserted for that node)
    env = {}
    observed_graph = Graph()
    # names of nodes whose output is observed (or inherits observation)
    observed_node_names_set = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in model.graph.nodes:
        # NOTE(review): names are only added to the set for the node being
        # visited, so this check appears never to fire — confirm.
        if node.name in observed_node_names_set:
            continue

        prefix = node.name + '_activation_post_process_'
        root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
        if root_node is None:
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            if qconfig is None:
                continue

            def insert_observer(node, observer, device):
                # Attach `observer` to the model under a fresh name, route
                # `node`'s output through it and mark `node` as observed.
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                setattr(model, observer_name, observer)
                self.activation_post_process_map[node.name] = observer
                env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)
                if device:
                    getattr(model, observer_name).to(device)

            if isinstance(obj, CustomModuleQuantizeHandler):
                # Swap the float custom module for its observed counterpart
                # on the parent module.
                custom_module = self.modules[node.target]
                observed_custom_module_class = \
                    get_observed_custom_module_class(type(custom_module))
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                mark_observed_custom_module(observed_custom_module, type(custom_module))
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name, observed_custom_module)

            # don't need to insert observer for output in dynamic quantization
            if self.is_dynamic_quant:
                continue

            # inserting observers for output of observed module, or mark the output
            # as observed
            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                def is_observed(input_arg):
                    # True for an observed Node, or a list whose items are
                    # all observed; None (falsy) for anything else.
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed_node_names_set
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                # NOTE(review): assumes node.args[0] is a Node; a non-Node
                # first operand would raise AttributeError on `.name` — confirm.
                if node.args[0].name in observed_node_names_set:
                    observed_node_names_set.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                new_observer = qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                insert_observer(node, new_observer, device)
        else:
            # Non-root member of a match: copied unchanged.
            env[node.name] = observed_graph.node_copy(node, load_arg)

        # Inputs recorded by _find_quants get an observer (weight or
        # activation, per the quants entry) unless already observed.
        if node.name not in observed_node_names_set and node.name in quants:
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                new_observer = \
                    qconfig.weight() if is_weight else qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                if device:
                    new_observer.to(device)
                self.activation_post_process_map[node.name] = new_observer
                setattr(model, observer_name, self.activation_post_process_map[node.name])
                env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)
    observed_graph.output(load_arg(model.graph.result))

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    return model