def _patch_arguments_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """
    Patches nodes by replacing their arguments with their corresponding values in mapping (supports regular types,
    tuples and slices).

    Args:
        gm: graph module whose node arguments are rewritten in place.
        mapping: lookup table; an argument that is a key of `mapping` is replaced by the associated value.
        lint_and_recompile: when True, lint the graph and recompile `gm` once every node has been patched.

    Only one level of nesting is handled: the elements of a tuple argument and the start / stop / step of a slice
    are patched, deeper structures are left untouched.
    """
    supported_types = (Node, str, int, float)

    def _lookup(value):
        # Unhashable values (lists, dicts, ...) can never be keys of `mapping`; pass them through instead of
        # letting `mapping.get` raise a TypeError.
        try:
            return mapping.get(value, value)
        except TypeError:
            return value

    def _patch_slice(s):
        return slice(_lookup(s.start), _lookup(s.stop), _lookup(s.step))

    def _patch_tuple_item(item):
        if isinstance(item, slice):
            return _patch_slice(item)
        return _lookup(item)

    graph = gm.graph
    for node in graph.nodes:
        new_args = []
        for arg in node.args:
            if isinstance(arg, tuple):
                new_args.append(tuple(_patch_tuple_item(a) for a in arg))
            elif isinstance(arg, slice):
                new_args.append(_patch_slice(arg))
            elif isinstance(arg, supported_types):
                new_args.append(mapping.get(arg, arg))
            else:
                # Any other top-level argument type is kept as is.
                new_args.append(arg)
        node.args = tuple(new_args)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule:
    """Wrap the value returned by `m` in the activation function selected by `fn`.

    Args:
        m: graph module to modify in place; its `output` node must have exactly one input node.
        fn: key into `activation_functions` selecting the implementation to apply.

    Returns:
        `m` itself, recompiled, with its output routed through the activation function.

    Raises:
        AssertionError: if `m` has no `output` node or the output node does not have exactly one input node.
        KeyError: if `fn` is not a key of `activation_functions`.
    """
    # Get output node
    output_node: Optional[Node] = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
            output_node = n
            break
    assert output_node

    # Get the actual output (the "input" of the output node). This is
    # the Node we want to wrap in a user-specified activation function
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # Wrap the actual output in a Proxy
    wrap_proxy = Proxy(wrap_node)

    # Get the implementation of the specified activation function and
    # symbolically trace it
    fn_impl = activation_functions[fn]
    fn_impl_traced = symbolic_trace(fn_impl)

    # Call the specified activation function using the Proxy wrapper for
    # `output_op`. The result of this call is another Proxy, which we
    # can hook into our existing Graph.
    # The insertion point is taken from `m` (the module being modified), not from an unrelated global.
    with m.graph.inserting_after(wrap_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node,)
        output_node.args = new_args

    m.recompile()
    return m
def _patch_getitem_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """Patches getitem nodes by replacing current arguments to their corresponding values in mapping.

    Args:
        gm: graph module whose `operator.getitem` call nodes are rewritten in place.
        mapping: lookup table; an index (or slice bound) that is a key of `mapping` is replaced by its value.
        lint_and_recompile: when True, lint the graph and recompile `gm` after patching.

    Handled index forms:
        - a tuple of indices: slices have their start / stop / step patched, ints are looked up, everything else
          is kept as is;
        - a single slice (e.g. `x[0:8]`): start / stop / step are patched;
        - any other single index: looked up in `mapping` when hashable, kept as is otherwise.
    """
    # TODO: combine this with the patch_argument function which seems to do almost the same thing.

    def _patch_slice(s):
        return slice(
            mapping.get(s.start, s.start),
            mapping.get(s.stop, s.stop),
            mapping.get(s.step, s.step),
        )

    graph = gm.graph
    for node in graph.nodes:
        if node.op == "call_function" and node.target == operator.getitem:
            indices = node.args[1]
            if isinstance(indices, tuple):
                new_indices = []
                for idx in indices:
                    if isinstance(idx, slice):
                        new_indices.append(_patch_slice(idx))
                    elif isinstance(idx, int):
                        new_indices.append(mapping.get(idx, idx))
                    else:
                        new_indices.append(idx)
                node.args = (node.args[0], tuple(new_indices))
            elif isinstance(indices, slice):
                # A bare slice must be patched field by field: looking the slice object itself up in `mapping`
                # either raises (slices are unhashable before Python 3.12) or never matches.
                node.args = (node.args[0], _patch_slice(indices))
            else:
                try:
                    new_index = mapping.get(indices, indices)
                except TypeError:
                    # Unhashable index (e.g. a list): it cannot be a key of `mapping`, keep it unchanged.
                    new_index = indices
                node.args = (node.args[0], new_index)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
    """Adds a `size` call that reads the encoder sequence length from a model input at runtime.

    A `size` node is inserted after every placeholder that is listed in `gm.dummy_inputs` and whose name does not
    contain "decoder"; the node created for the last such placeholder is the one returned.

    Args:
        gm: graph module to modify in place.
        lint_and_recompile: when True, lint the graph and recompile `gm` after the insertion.

    Returns:
        The inserted node computing the encoder sequence length.

    Raises:
        ValueError: if no suitable placeholder was found.
    """
    fx_graph = gm.graph
    known_inputs = set(gm.dummy_inputs.keys())
    seq_len_node = None

    for candidate in fx_graph.nodes:
        is_encoder_input = (
            candidate.op == "placeholder" and candidate.name in known_inputs and "decoder" not in candidate.name
        )
        if not is_encoder_input:
            continue
        with fx_graph.inserting_after(candidate):
            # Two input layouts are possible:
            #   - num_choices < 0: not a "multiple choice" task, inputs are
            #     [batch_size, sequence_length] => the sequence length is dimension 1
            #   - num_choices > 0: "multiple choice" task, inputs are
            #     [batch_size, num_choices, sequence_length] => the sequence length is dimension 2
            if gm.num_choices < 0:
                seq_dim = 1
            else:
                seq_dim = 2
            seq_len_node = fx_graph.call_method("size", args=(candidate, seq_dim))

    if seq_len_node is None:
        raise ValueError("Could not insert the node that computes the encoder sequence length")

    if lint_and_recompile:
        fx_graph.lint()
        gm.recompile()

    # Useful when retracing for quantization.
    if hasattr(gm, "_qconfig_map"):
        gm._qconfig_map[seq_len_node.name] = None

    return seq_len_node
def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
    """Transformation that enables traced models to perform inference on dynamic input shapes.

    The static batch size / sequence length recorded on `gm` are mapped to nodes that read those values from the
    model inputs at runtime; `view` calls and `getitem` indices are then rewritten to use these nodes, unused nodes
    are removed and the module is recompiled.

    Args:
        gm: traced module to transform in place. `use_dynamic_batch_size`, `use_dynamic_sequence_length` and the
            matching `static_batch_size` / `static_sequence_length` / `num_choices` attributes are read from it.
        is_retracing: currently unused by this function.

    Side effects:
        Sets `gm.static2dynamic` (static value -> dynamic node) and `gm.dynamic2static` (the inverse mapping).
    """
    graph = gm.graph
    static2dynamic = {}

    # Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
    if gm.use_dynamic_batch_size:
        batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
        static2dynamic[gm.static_batch_size] = batch_size_node
        if gm.num_choices > 0:
            # Multiple-choice models also use `batch_size * num_choices` as a static dimension.
            with graph.inserting_after(batch_size_node):
                flattened_batch_node = graph.call_function(operator.mul, args=(batch_size_node, gm.num_choices))
            static2dynamic[gm.static_batch_size * gm.num_choices] = flattened_batch_node
            # Useful when retracing for quantization.
            # `_qconfig_map` is keyed by node *name* (as the `_insert_*_node_` helpers do), not by Node object.
            if hasattr(gm, "_qconfig_map"):
                gm._qconfig_map[flattened_batch_node.name] = None

    if gm.use_dynamic_sequence_length:
        encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
        # NOTE: if the static sequence length equals a static batch-size key, this assignment overwrites it.
        static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node

        # TODO: do the same for the decoder.

    _change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
    _patch_getitem_(gm, static2dynamic, lint_and_recompile=False)

    remove_unused_nodes_(gm, lint_and_recompile=False)

    # Lint / recompile once, after all the passes above have run.
    graph.lint()
    gm.recompile()

    gm.static2dynamic = static2dynamic
    gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
def tensorexpr_compile(fx_module: fx.GraphModule, flat_args) -> Callable:
    """Compiles the given fx_module using TensorExpr Kernel

    Args:
        fx_module: graph module to compile. Its output node is rewritten in place (outputs that are graph inputs
            or repeats of an earlier output are dropped) and the module is recompiled.
        flat_args: example inputs; the ``torch.Tensor`` entries must all live on one single device.

    Returns:
        A callable ``f(*args)`` that runs the compiled kernel and returns a list with one entry per output of the
        original graph.

    Raises:
        AssertionError: if the tensors in ``flat_args`` do not share exactly one device.
    """
    inp_devices = {i.device for i in flat_args if isinstance(i, torch.Tensor)}
    assert len(inp_devices) == 1
    inp_device = list(inp_devices)[0]
    inputs = []
    # One (from_kernel_output, index) pair per original output, in original output order.
    output_refs = []
    for node in fx_module.graph.nodes:
        if node.op == "placeholder":
            inputs.append(node)
        elif node.op == "output":
            outputs = node.args[0]
            if not isinstance(outputs, Iterable):
                outputs = (outputs, )
            # Outputs that the kernel really has to produce (no graph inputs, no duplicates).
            new_outputs = []
            for idx, output in enumerate(outputs):
                # Appends (bool, idx) pairs
                # if True, read from kernel outputs
                # if False, read from kernel inputs
                if output in inputs:
                    output_refs.append((False, inputs.index(output)))
                elif output in outputs[:idx]:
                    # Repeated output: reuse the reference recorded for its first occurrence.
                    output_refs.append(
                        (True, output_refs[outputs.index(output)][1]))
                else:
                    output_refs.append((True, len(new_outputs)))
                    new_outputs.append(output)
            node.args = (tuple(new_outputs), )
    fx_module.graph.lint()
    fx_module.recompile()

    # Move the consecutively numbered `_tensor_constant{i}` attributes to the input device; the scan stops at
    # the first missing index and never looks past index 99.
    for i in range(0, 100):
        attr = f"_tensor_constant{i}"
        if hasattr(fx_module, attr):
            setattr(fx_module, attr, getattr(fx_module, attr).to(inp_device))
        else:
            break

    jit_module = torch.jit.trace(fx_module, flat_args)
    jit_module = torch.jit.freeze(jit_module.eval())
    # NOTE(review): the calls below are private torch._C / torch._C._te entry points; their exact effects are not
    # visible from this file. Their names suggest graph preparation steps for the TensorExpr kernel -- confirm
    # against the installed PyTorch version before reordering or removing any of them.
    torch._C._jit_trace_module(jit_module._c, tuple(flat_args))
    torch._C._te.remove_unused_self_argument(jit_module.graph)
    torch._C._te.annotate_input_shapes(jit_module.graph, tuple(flat_args))
    torch._C._jit_pass_lower_all_tuples(jit_module.graph)
    te_kernel = torch._C._te.TensorExprKernel(jit_module.graph)

    def f(*args):
        outs = te_kernel.run(args)
        if not isinstance(outs, tuple) and not isinstance(outs, list):
            outs = (outs, )
        # Rebuild the original output list from kernel outputs and passed-through inputs.
        real_outs = []
        for out in output_refs:
            if out[0]:
                real_outs.append(outs[out[1]])
            else:
                real_outs.append(args[out[1]])
        return real_outs

    return f
def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
    """Removes all the unused nodes in a GraphModule.

    A node is unused when it has no users and is neither a placeholder nor the output node.

    Args:
        gm: graph module whose graph is pruned in place.
        lint_and_recompile: when True, lint the graph and recompile `gm` after the removal.
    """
    graph = gm.graph
    # Walk from the end of the graph towards its inputs: erasing a node can leave its producers without users,
    # and those producers are visited afterwards, so whole chains of dead nodes are removed in a single pass.
    for node in reversed(graph.nodes):
        if not node.users and node.op not in ("placeholder", "output"):
            graph.erase_node(node)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
    """
    Compiles the :attr:`fx_g` with Torchscript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled. Its nodes are
            rewritten in place before scripting.
        _: ignored.

    Returns:
        Torch scripted model.
    """
    aten = torch.ops.aten

    for node in fx_g.graph.nodes:
        if node.target in (aten.new_zeros, aten.new_empty):
            # An empty size list is replaced by the size [1].
            if node.args[1] == []:
                node.args = (node.args[0], [1], *node.args[2:])
        elif node.target is aten.masked_fill and node.args[2] == float("-inf"):
            # Fx graph to torchscript fails for -inf
            node.args = (*node.args[:2], -3.403 * 10**37, *node.args[3:])

    # torch.device keyword arguments are replaced by their device type string.
    for node in fx_g.graph.nodes:
        node.kwargs = {
            key: (value.type if isinstance(value, torch.device) else value)
            for key, value in node.kwargs.items()
        }

    fx_g.graph.lint()

    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
    # Consecutively numbered tensor constants are moved to CUDA, stopping at the first missing index.
    index = 0
    while index < 1000:
        constant_name = f"_tensor_constant{index}"
        if not hasattr(fx_g, constant_name):
            break
        setattr(fx_g, constant_name, getattr(fx_g, constant_name).cuda())
        index += 1

    fx_g.recompile()

    scripted = torch.jit.script(fx_g)
    torch._C._jit_pass_remove_mutation(scripted.graph)
    frozen = torch.jit.freeze(scripted.eval())
    return torch.jit.optimize_for_inference(frozen)
def _change_view_methods_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """
    Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the
    batch_size / sequence_length nodes.

    Args:
        gm: graph module whose `view` method calls are rewritten in place.
        mapping: lookup table; a `view` argument that is a key of `mapping` is replaced by the associated value.
        lint_and_recompile: when True, lint the graph and recompile `gm` after patching.

    A shape passed as one tuple or list (`x.view((b, s))`, `x.view([b, s])`) is first flattened into separate
    arguments (`x.view(b, s)`) so that each dimension can be looked up individually.
    """

    def _lookup(arg):
        # Unhashable arguments cannot be keys of `mapping`; keep them instead of raising a TypeError.
        try:
            return mapping.get(arg, arg)
        except TypeError:
            return arg

    graph = gm.graph
    for node in graph.nodes:
        if node.op == "call_method" and node.target == "view":
            if isinstance(node.args[1], (tuple, list)):
                node.args = (node.args[0], *node.args[1])
            node.args = tuple(_lookup(arg) for arg in node.args)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult:
    """Drop repeated nodes from the output tuple of `gm`.

    Nodes are compared by name; the first occurrence of each name is kept. Only outputs that are an iterable made
    exclusively of `fx.Node` are handled; for any other output `gm` is left untouched.

    Returns:
        A `RemoveDuplicateResult` holding `gm` and a replacement map where entry `i` is the index, in the new
        output tuple, of the node that replaces old output `i`. The map is `None` when the output was not
        handled.
    """
    output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
    assert len(output_nodes) == 1, \
        f"Expecting exactly one `output` node, but got {len(output_nodes)}"
    (output_node,) = output_nodes

    # Output op only uses its `args[0]`, and it does not have `kwargs`.
    # https://pytorch.org/docs/stable/fx.html#torch.fx.Node
    args: t.Sequence[t.Any] = output_node.args[0]

    # Anything that is not an iterable of fx.Node (e.g. a single returned value) is out of scope for this pass.
    if not isinstance(args, t.Iterable) or not all(isinstance(a, fx.Node) for a in args):
        return RemoveDuplicateResult(replacement_map=None, module=gm)

    # Old output index -> index of the surviving node; starts as the identity mapping.
    replacement_map: t.List[int] = list(range(len(args)))
    # Node name -> position of that node in the de-duplicated output tuple.
    name_to_idx: t.Dict[str, int] = {}
    args_new = []
    changed = False

    for idx, a in enumerate(args):
        assert isinstance(a, fx.Node), \
            f"Expecting fx.Node instance, but got: {type(a)}"
        if a.name in name_to_idx:
            changed = True
            _LOGGER.warning(
                f"Replaced duplicate output arg '{a.name}': "
                f"{idx} -> {name_to_idx[a.name]}"
            )
        else:
            name_to_idx[a.name] = len(args_new)
            args_new.append(a)
        replacement_map[idx] = name_to_idx[a.name]

    output_node.args = (tuple(args_new),)
    if changed:
        gm.recompile()
    return RemoveDuplicateResult(replacement_map, module=gm)
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with Torchscript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled. Its nodes are
            rewritten in place before scripting.
        inps: inputs the scripted model is called with once before it is returned.

    Returns:
        Torch scripted model.
    """
    with _disable_jit_autocast():
        strip_overloads(fx_g)

        # A `_to_copy` whose only argument besides the input is `dtype` is retargeted to `aten.to`.
        for node in fx_g.graph.nodes:
            is_dtype_only_copy = (
                node.target == torch.ops.aten._to_copy
                and len(node.args) == 1
                and len(node.kwargs) == 1
                and "dtype" in node.kwargs
            )
            if is_dtype_only_copy:
                node.target = torch.ops.aten.to

        # torch.device keyword arguments are replaced by their device type string.
        for node in fx_g.graph.nodes:
            node.kwargs = {
                key: (value.type if isinstance(value, torch.device) else value)
                for key, value in node.kwargs.items()
            }

        fx_g.graph.lint()
        fx_g.recompile()

        scripted = torch.jit.script(fx_g)
        torch._C._jit_pass_remove_mutation(scripted.graph)
        frozen = torch.jit.freeze(scripted.eval())
        f = torch.jit.optimize_for_inference(frozen)

        f(*inps)
    return f
def remove_duplicate_output_args(
    top_level: fx.GraphModule,
    target_subnets: t.Collection[str]
) -> t.Mapping[str, "RemoveDuplicateResult"]:
    """De-duplicates the outputs of selected subnets and repairs their uses.

    For every `call_module` node of `top_level` whose name is in `target_subnets`, duplicate output args are
    removed from the called submodule, and the users of that node in `top_level` are re-pointed to the surviving
    output index.

    This pass must be called after acc split on the top-level net and
    subsequent calls to the acc trace on the subnets.

    Both the subnets and the top level module may be modified.

    Returns:
        a mapping of the target subnet name to its deduplication result
    """
    processed_subnets = {}

    for node in top_level.graph.nodes:  # type: fx.Node
        if node.op != "call_module" or node.name not in target_subnets:
            continue

        assert isinstance(node.target, str)
        sub_gm = top_level.get_submodule(node.target)
        assert isinstance(sub_gm, fx.GraphModule)

        replace_res = _remove_duplicate_output_args(sub_gm)
        processed_subnets[node.name] = replace_res
        index_map = replace_res.replacement_map
        if index_map is None:
            # The subnet output was not a plain sequence of nodes; nothing to fix up.
            continue
        sub_gm.recompile()

        # Snapshot the users first: rewriting `user.args` changes `node.users` while we loop.
        remapped = False
        for user in list(node.users):
            idx = _ensure_proper_output_use(user, node)
            idx_new = index_map[idx]
            if idx_new == idx:
                continue
            user.args = (user.args[0], idx_new)
            remapped = True

        if remapped:
            top_level.recompile()

    return processed_subnets
def change_fx_graph_return_to_tuple(fx_g: fx.GraphModule):
    """Rewrite list-valued outputs of `fx_g` so that the graph returns a tuple (or a single value).

    For every output node whose argument is a list:
      - `None` elements are dropped;
      - if exactly one element remains it is returned directly, without a wrapping tuple;
      - otherwise the remaining elements are returned as a tuple.
    Output nodes whose argument is not a list are left unchanged.

    Args:
        fx_g: graph module to modify in place.

    Returns:
        `fx_g`, linted and recompiled.
    """
    for node in fx_g.graph.nodes:
        if node.op != "output":
            continue
        # output nodes always have one argument
        node_arg = node.args[0]
        if not isinstance(node_arg, list):
            continue
        # Don't return NoneType elements.
        out_nodes = [out_node for out_node in node_arg if out_node is not None]
        if len(out_nodes) == 1:
            # If there is a single tensor/element to be returned don't create a tuple
            # for it. `node.args` itself must still be a tuple, not the filtered list.
            node.args = (out_nodes[0],)
        else:
            node.args = (tuple(out_nodes),)
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
    `gm` gets run.

    Args:
        gm: graph module to insert breakpoint. It is then recompiled for it to
            take effect.

    Returns:
        the `gm` with breakpoint inserted.
    """

    def insert_pdb(body):
        return ["import pdb; pdb.set_trace()\n", *body]

    def make_transformer(cur_transform):
        # Build the new code transformer: run the previously registered one (if any) first,
        # then put the breakpoint in front of whatever it produced.
        def transform(body):
            upstream = cur_transform(body) if cur_transform else body
            return insert_pdb(upstream)

        return transform

    with gm.graph.on_generate_code(make_transformer=make_transformer):
        gm.recompile()

    return gm
def force_lazy_device(model: fx.GraphModule):
    """
    Rewrites the device arguments of `model`'s graph so that tensors are created on the lazy device.

    Factory methods in a Fx graph may create tensors for a specific eager device. If nothing is done, those
    eager tensors get mixed with lazy tensors and cause a crash. Every `torch.device` found directly in a node's
    args / kwargs is replaced by a lazy device with the same index, and the module is recompiled.
    """

    def tolazydevice(dev):
        if not isinstance(dev, torch.device):
            return dev
        return torch.device("lazy", index=dev.index)

    def hasDeviceArg(args, kwargs):
        for candidate in itertools.chain(args, kwargs.values()):
            if isinstance(candidate, torch.device):
                return True
        return False

    for nd in model.graph.nodes:
        nd.args = tuple(map(tolazydevice, nd.args))
        nd.kwargs = {name: tolazydevice(value) for name, value in nd.kwargs.items()}

        # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return
        # eager tensors on the default device
        # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove,
        # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
        # Those calls carry no explicit device argument, so there is nothing to override;
        # instead a lazy device argument is added explicitly for the covered tensor
        # factory functions.
        #
        # TODO: This solution is not ideal since we may miss some factory methods. In future
        # when we support lazy mode, this method can be replaced by that.
        if nd.target not in tensor_factory_functions:
            continue
        if hasDeviceArg(nd.args, nd.kwargs):
            continue
        # nd.kwargs is immutable, so assign a fresh dict that includes the device.
        nd.kwargs = {**nd.kwargs, "device": torch.device("lazy")}

    model.recompile()
def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
    """Adds a `size` call that reads the batch size (dimension 0) from a model input at runtime.

    A `size` node is inserted after every placeholder listed in `gm.dummy_inputs`; the node created for the last
    such placeholder is the one returned.

    Args:
        gm: graph module to modify in place.
        lint_and_recompile: when True, lint the graph and recompile `gm` after the insertion.

    Returns:
        The inserted node computing the batch size.

    Raises:
        ValueError: if no suitable placeholder was found.
    """
    fx_graph = gm.graph
    known_inputs = set(gm.dummy_inputs.keys())
    size_node = None

    for candidate in fx_graph.nodes:
        if candidate.op != "placeholder" or candidate.name not in known_inputs:
            continue
        with fx_graph.inserting_after(candidate):
            size_node = fx_graph.call_method("size", args=(candidate, 0))

    if size_node is None:
        raise ValueError("Could not insert the node that computes the batch size")

    if lint_and_recompile:
        fx_graph.lint()
        gm.recompile()

    # Useful when retracing for quantization.
    if hasattr(gm, "_qconfig_map"):
        gm._qconfig_map[size_node.name] = None

    return size_node
def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule, _joint_inputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graph, we copy the joint graph, manually set the
    outputs to just original forward or backward outputs. And then we run the
    resulting graphs through dead code elimination.

    The values saved for the backward pass are selected with a minimum cut on a
    flow network built from the joint graph (see the inline comments below).

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing. Its graph is replaced in
            place by the result of the CSE pass.
        _joint_inputs: not used by this function.

    Returns:
        Returns the generated forward and backward Fx graph modules.

    Raises:
        RuntimeError: if ``networkx`` cannot be imported.
    """
    try:
        import networkx as nx
    except ImportError:
        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")

    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()
    fx_g = joint_module.graph

    #  add the CSE pass
    cse_graph = fx_graph_cse(fx_g)
    joint_module.graph = cse_graph
    # Same object as the post-CSE `joint_module.graph` assigned just above.
    full_bw_graph = joint_module.graph

    # Node name -> node of the post-CSE joint graph. Used to translate nodes of the
    # extracted forward-only graph back to joint-graph nodes.
    name_to_node = {}
    for node in joint_module.graph.nodes:
        name_to_node[node.name] = node

    def classify_nodes(joint_module):
        # Backward nodes: placeholders whose target contains "tangents", plus every user
        # of a node already in the set. A user is added when its producer is visited, so
        # this presumably relies on graph.nodes listing producers before their users.
        required_bw_nodes = set()
        for node in joint_module.graph.nodes:
            if node.op == 'placeholder' and "tangents" in node.target:
                required_bw_nodes.add(node)
            if node in required_bw_nodes:
                for user in node.users:
                    required_bw_nodes.add(user)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
        # Forward nodes: every non-output node of the graph extracted from the primal
        # inputs to the forward outputs, mapped back by name.
        required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes
                             if node.op != 'output'}
        unclaimed_nodes = {node for node in joint_module.graph.nodes
                           if node not in required_fw_nodes and node not in required_bw_nodes}
        return required_fw_nodes, required_bw_nodes, unclaimed_nodes

    # `unclaimed_nodes` is not used below.
    required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module)

    # `dist_from_bw` is 0 for nodes outside the forward set; for a forward node it is
    # 1 + the smallest `dist_from_bw` among its users (int(1e9) when it has no users).
    # The reverse walk visits the users of a node before the node itself.
    for node in reversed(joint_module.graph.nodes):
        if node not in required_fw_nodes:
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

    aten = torch.ops.aten

    pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward]  # noqa: E501
    misc_ops = [aten.to, aten.type_as, operator.getitem]

    reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax]  # noqa: E501

    # not recomputed by default since these are kinda expensive/hard to fuse into
    # norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward]  # noqa: E501

    # Not used by default since NVFuser can't fuse view ops
    # view_ops = [aten.expand, aten.clone, aten.transpose, aten.t, aten.view, aten._unsafe_view, aten.permute, aten.transpose, aten.t, aten._reshape_alias, aten.squeeze, aten.unsqueeze, aten.reshape, aten.cat, aten.slice, aten.split, aten.select, aten.repeat]  # noqa: E501

    # These are the view ops that NVFuser can fuse
    view_ops = [aten.squeeze, aten.unsqueeze]
    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d]  # noqa: E501

    # Only consulted when AGGRESSIVE_RECOMPUTATION is True.
    unrecomputable_ops = random_ops + compute_intensive_ops

    recomputable_ops = set(
        pointwise_ops
        + misc_ops
        + reduction_ops
        + view_ops
    )
    fusible_ops = recomputable_ops | set(random_ops)

    # Hard-coded switch: when True, only ops in `unrecomputable_ops` are banned from
    # recomputation; when False, everything outside `recomputable_ops` is banned.
    AGGRESSIVE_RECOMPUTATION = False

    def ban_recomputation(node):
        # Returns True when `node` must not be recomputed in the backward pass.
        if AGGRESSIVE_RECOMPUTATION:
            return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
        else:
            if node.op != 'call_function':
                return False
            if get_aten_target(node) not in recomputable_ops:
                return True
            # If the output of the reduction is 4x smaller (arbitrary choice),
            # then we don't allow recomputation.
            if get_aten_target(node) in reduction_ops:
                input_tensors_size = sum(_size_of(i.meta['tensor_meta']) for i in node.args if isinstance(i, fx.Node))
                output_size = _size_of(node.meta['tensor_meta'])
                return (output_size * 4 < input_tensors_size)
            return False

    def is_fusible(a, b):
        return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops

    def is_materialized(node):
        # Placeholders always count as materialized; any other node does as soon as one
        # of its users is not fusible with it.
        if node.op == 'placeholder':
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node):
        # Cut capacity of `node`: its `_size_of` value plus the distance bias, doubled
        # when the node is not materialized.
        mem_sz = _size_of(node.meta['tensor_meta'])

        # Heuristic to bias towards nodes closer to the backwards pass
        mem_sz = int(mem_sz + node.dist_from_bw)

        if is_materialized(node):
            return mem_sz
        else:
            return mem_sz * 2

    # Flow network: every graph node becomes an "<name>_in" -> "<name>_out" edge whose
    # capacity is the node weight; all other edges have infinite capacity.
    nx_graph = nx.DiGraph()
    for node in full_bw_graph.nodes:
        if node.op == 'output':
            continue

        # Backward nodes are wired straight to the sink and get no "_in" -> "_out" edge.
        if node in required_bw_nodes:
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if node.op == 'placeholder' and "primals" in node.target:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if ban_recomputation(node) and node in required_fw_nodes:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # Nodes without 'tensor_meta' get an infinite weight.
        if 'tensor_meta' not in node.meta:
            weight = math.inf
        else:
            weight = get_node_weight(node)

        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)

    # `cut_value` is not used below; only the partition is.
    cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    reachable, non_reachable = partition
    # Edges going from the source side to the sink side of the partition.
    cutset = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    cut_nodes = set()
    for node_in, node_out in cutset:
        # Every cut edge is expected to be the "_in" -> "_out" edge of one single node:
        # stripping the "_in" / "_out" suffixes must give the same node name.
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
    return _extract_fwd_bwd_modules(joint_module, saved_values)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> None:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Raises:
        AssertionError: if ``pattern`` has no parameters, or if ``pattern``
            and ``replacement`` do not take the same number of parameters.

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
        # `lookup` represents all the nodes in `original_graph`
        # that are part of `pattern`
        lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
        for n in lookup.keys():
            if n.op == "placeholder" or lookup[n].op == "output":
                continue
            for user in n.users:
                # If this node has users that were not in
                # `lookup`, then it must leak out of the
                # pattern subgraph
                if user not in lookup:
                    return False
        return True

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:
        if matcher.matches_subgraph_from_anchor(anchor):
            # It's not a match if the pattern leaks out into the rest
            # of the graph
            if pattern_is_contained(matcher.nodes_map):
                # Record the match exactly once per anchor, with a shallow
                # copy of `nodes_map` (the matcher reuses its map for the
                # next anchor).
                matches.append(
                    Match(anchor=anchor, nodes_map=copy.copy(matcher.nodes_map)))

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()

    # Return TRUE if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        for n in match.nodes_map.values():
            if n in replaced_nodes and n.op != "placeholder":
                return True
        return False

    for match in matches:
        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        # Map replacement graph nodes to their copy in `original_graph`
        val_map: Dict[Node, Node] = {}

        pattern_placeholders = [
            n for n in pattern_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders)
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) == len(replacement_placeholders)
        # Placeholders are paired positionally: the i-th parameter of
        # `replacement` corresponds to the i-th parameter of `pattern`.
        placeholder_map = {
            r: p for r, p in zip(replacement_placeholders, pattern_placeholders)
        }

        # node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        def mark_node_as_replaced(n: Node) -> None:
            if n not in match.nodes_map.values():
                return
            for n_ in n.all_input_nodes:
                mark_node_as_replaced(n_)
            replaced_nodes.add(n)

        mark_node_as_replaced(subgraph_output)

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        for replacement_node in replacement_placeholders:
            # Get the `original_graph` placeholder node
            # corresponding to the current `replacement_node`
            pattern_node = placeholder_map[replacement_node]
            original_graph_node = match.nodes_map[pattern_node]
            # Populate `val_map`
            val_map[replacement_node] = original_graph_node

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph, val_map)
        assert isinstance(copied_output, Node)

        # We only want to copy in the output node from `pattern` if we
        # have an output-output match. Otherwise, we leave out the
        # `pattern` output node so we don't have two outputs in the
        # resultant graph
        if subgraph_output.op != "output":
            subgraph_output = subgraph_output.args[0]  # type: ignore
        subgraph_output.replace_all_uses_with(copied_output)

        # Erase the `pattern` nodes
        for node in reversed(original_graph.nodes):
            if len(node.users) == 0 and node.op != "output":
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        convert_qconfig_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """Convert an observed model (a module with observer calls) to a quantized model.

    The rule is simple:

    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to
       reference quantized modules. This requires us to know whether the dtype
       configured for the weight is supported in the backend; this is done in
       the prepare step and the result is stored in observed_node_names, so we
       can decide whether we need to swap the module based on this set

    standalone_module means a submodule that is not inlined in the parent
    module and will be quantized separately as one unit. Whether the
    input/output of the result is quantized is specified by
    prepare_custom_config_dict, with input_quantized_idxs and
    output_quantized_idxs; please see docs for prepare_fx for details.

    Args:
        model: the observed GraphModule. Its saved state is read back through
            ``restore_state`` and the ``_qconfig_map`` and
            ``_equalization_qconfig_map`` attributes (plus ``_qconfig_dict``
            and ``_is_qat`` when ``convert_qconfig_dict`` is given).
        is_reference: when False, the model is moved to CPU before conversion
            and the lowering passes at the end of this function are run; when
            True, both are skipped.
        convert_custom_config_dict: optional configuration; the keys read here
            are "observed_to_quantized_custom_module_class" and
            "preserved_attributes". ``None`` is treated as an empty dict.
        is_standalone_module: accepted for interface compatibility; it is not
            read by this function body.
        _remove_qconfig_flag: when True, ``_remove_qconfig`` is called on the
            result before returning.
        convert_qconfig_dict: optional qconfig dict. When truthy, a new qconfig
            map is generated from it and checked against the map stored on the
            model; every entry must either equal the stored one or be None.
        backend_config_dict: backend configuration; ``None`` selects
            ``get_native_backend_config_dict()``.

    Returns:
        The ``QuantizedGraphModule`` built from the converted graph (after the
        lowering passes when ``is_reference`` is False).
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO this should be removed now that gpu support for quantization is being supported.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if convert_qconfig_dict:
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)
        convert_dict_to_ordered_dict(convert_qconfig_dict)
        if model._is_qat:
            convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, {})
        convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)

        compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict)  # type: ignore[arg-type]
        convert_qconfig_map = generate_qconfig_map(
            model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
        # check the convert_qconfig_map generated and ensure that all the values either match
        # what was set in prepare qconfig_map or are set to None in the convert_qconfig_map.
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, f'Expected key {k} in convert qconfig_map'
            if convert_qconfig_map[k] is not None:
                # The message is built from adjacent literals so that no
                # backslash or source indentation leaks into the text.
                assert qconfig_equals(v, convert_qconfig_map[k]), (
                    f'Expected k {k} to have the same value in prepare qconfig_dict '
                    f'and convert qconfig_dict, found {v} updated to {convert_qconfig_map[k]}.')
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {})

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = get_native_backend_config_dict()
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config_dict)
    fused_module_classes = get_fused_module_classes(backend_config_dict)
    # All module types handled by convert_weighted_module; built once because
    # none of its inputs change while the graph is being walked.
    weighted_module_classes = set(root_module_classes).union(
        qat_module_classes).union(fused_module_classes)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    # Iterate over a snapshot: the handlers below insert and erase nodes.
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if not output_quantized_idxs:
                continue
            # Results are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            module = modules[node.target]
            module_type = type(module)
            if is_activation_post_process(module):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    _replace_observer_with_dequantize_node(node, model.graph)
                else:
                    _replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(module):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif module_type in weighted_module_classes:
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if module_type in fused_module_classes and \
                        type(module[0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif module_type in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model.recompile()

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = duplicate_dequantize_node(model)
        model = duplicate_quantize_dynamic_node(model)
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
        model = remove_quant_dequant_pairs(model)
        model = remove_extra_dequantize(model)
    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    return model


def _replace_observer_with_quantize_dequantize_node(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]) -> None:
    """Replace an activation_post_process module call node with quantize and dequantize nodes.

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    The observer is simply removed (its users are rewired to its input) when
    no quantize info is available for it, or when the qconfigs of all of its
    producers and consumers are None.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
    observer_module = modules[node.target]
    maybe_quantize_node_info = get_quantize_node_info(observer_module)
    # Skip replacing observers with quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        has_none_qconfig(n, qconfig_map)
        for n in list(node.args) + list(node.users))
    if skip_replacement or maybe_quantize_node_info is None:
        # didn't find corresponding quantize op and info for the observer_module
        # so we just remove the observer
        node.replace_all_uses_with(node.args[0])
        graph.erase_node(node)
        return

    # otherwise, we can convert the observer module call to quantize/dequantize nodes
    node_type, quantize_op, qparams = maybe_quantize_node_info
    # replace observer node with quant - dequant node
    with graph.inserting_before(node):
        input_node = node.args[0]
        inputs = [input_node]
        for key, value in qparams.items():
            # TODO: we can add the information of whether a value needs to
            # be registered as an attribute in qparams dict itself
            if key in ['_scale_', '_zero_point_']:
                # For scale and zero_point values we register them as buffers in the root module.
                # TODO: maybe need more complex attr name here
                qparam_node = create_getattr_from_value(
                    model, graph, module_path + prefix + key, value)
                inputs.append(qparam_node)
            else:
                # for qparams that are not scale/zero_point (like axis, dtype) we
                # store them as literals in the graph.
                inputs.append(value)

        quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
        dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)


# this is a temporary hack for custom module, we may want to implement
# this properly after the custom module class design is finalized
def _replace_observer_with_dequantize_node(node: Node, graph: Graph) -> None:
    """Drop the observer `node` and insert a dequantize after the custom module call feeding it."""
    call_custom_module_node = node.args[0]
    assert isinstance(call_custom_module_node, Node), \
        f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)
    insert_dequantize_node(call_custom_module_node, graph)