def deepcopy_graph(gm: GraphModule) -> GraphModule:
    """
    Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model
    was traced with dynamic axes, and what the values were if that is the case.
    """
    # First, create a copy of the module without the graph.
    graph = gm.__dict__.pop("_graph")
    fake_mod = torch.nn.Module()
    fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
    gm.__dict__["_graph"] = graph

    # Then, copy the graph.
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(graph, val_map=val_map)
    graph_clone.output(output_val)

    # Finally, create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
    # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
    clone = gm.__class__(fake_mod, graph_clone)

    # Restore the dynamic axes related attributes to the clone.
    attributes = _cache_attributes(gm)
    attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
    attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
    _restore_attributes_(clone, attributes)

    return clone

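# A standalone sketch of the Graph-copy idiom used in deepcopy_graph above, under the assumption
# that the module was traced with torch.fx.symbolic_trace: Graph.graph_copy copies the nodes into
# a fresh Graph, fills val_map with an old-node -> new-node mapping, and returns the value that
# should be fed to graph_clone.output(). The _Tiny module below is illustrative only.
import copy

import torch
from torch.fx import Graph, GraphModule, symbolic_trace


class _Tiny(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return torch.relu(x) + 1


def _copy_graph_example():
    gm = symbolic_trace(_Tiny())
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(gm.graph, val_map=val_map)
    graph_clone.output(output_val)
    # Reuse gm's attributes/submodules together with the copied graph.
    clone = GraphModule(gm, graph_clone)
    x = torch.randn(3)
    assert torch.equal(clone(x), gm(x))
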
def quantize(self):
    self.quantized_graph = Graph()
    self.delegate = DelegateBase(self.quantized_graph)

    env = {}
    quant_env = {}

    def load_arg(n, quantized):
        if not quantized:
            if n.name not in env and n.name in quant_env:
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]
        else:
            if n.name not in quant_env and n.name in env:
                quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
            return quant_env[n.name]

    def copy_recursive(node):
        def load_or_emit(n):
            if n.name in env or n.name in quant_env:
                return load_arg(n, quantized=False)
            else:
                return copy_recursive(n)

        r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
        return r

    for node in self.graph.nodes:
        root_node, obj = self.matches.get(node.name, (None, None))
        if root_node is None:
            # not quantized, just copy it over
            env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
        elif root_node is node:
            r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
            if r is NotImplemented:
                # the quantizer chose not to quantize the node; take the entire match and just copy it over
                env[node.name] = copy_recursive(node)
            else:
                quant_env[node.name] = r

    self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
    return GraphModule(self.root, self.quantized_graph)

def test_package_fx_with_imports(self):
    import package_a.subpackage

    # Manually construct a graph that invokes a leaf function
    graph = Graph()
    a = graph.placeholder("x")
    b = graph.placeholder("y")
    c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
    d = graph.call_function(torch.sin, (c,))
    graph.output(d)
    gm = GraphModule(torch.nn.Module(), graph)

    f = BytesIO()
    with PackageExporter(f) as pe:
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", gm)
    f.seek(0)

    pi = PackageImporter(f)
    loaded_gm = pi.load_pickle("model", "model.pkl")

    input_x = torch.rand(2, 3)
    input_y = torch.rand(2, 3)
    self.assertTrue(torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

    # Check that the packaged version of the leaf_function dependency is
    # not the same as in the outer env.
    packaged_dependency = pi.import_module("package_a.subpackage")
    self.assertTrue(packaged_dependency is not package_a.subpackage)

def test_graph_fns(self):
    g = Graph()
    a = g.placeholder('a')
    b = g.call_module('linear', (a,))
    c = g.get_param('bias')
    d = g.call_method('add', (b, c))
    e = g.call_function(torch.sin, (d,))
    g.output(e)

    mod = torch.nn.Module()
    mod.linear = torch.nn.Linear(3, 4)
    mod.bias = torch.rand(4)
    gm = GraphModule(mod, g)

    input = torch.rand(3)
    r = gm(input)
    ref = torch.sin(mod.linear(input) + mod.bias)
    self.assertEqual(r, ref)

def test_remove_uses(self):
    g: torch.fx.Graph = Graph()
    x: torch.fx.Node = g.placeholder('x')
    relu: torch.fx.Node = g.call_function(torch.relu, (x,))
    neg: torch.fx.Node = g.call_function(torch.neg, (relu,))
    g.output(neg)

    neg.replace_all_uses_with(relu)
    g.erase_node(neg)

    self.assertTrue(neg not in relu.users)

By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.

.. code-block:: python

    def forward(self, x, y):
        cat_1 = torch.cat([x, y]);  x = y = None
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1
'''

# Create a graph independently of symbolic tracing
graph = Graph()
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and graph's default tracer
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)
# y = Proxy(raw1)
# z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)

def trace(self,
          root: st.Union[torch.nn.Module, st.Callable[..., Any]],
          concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
    if isinstance(root, torch.nn.Module):
        self.root = root
        fn = type(root).forward
        self.submodule_paths = {mod: name for name, mod in root.named_modules()}
    else:
        self.root = torch.nn.Module()
        fn = root

    tracer_cls: Optional[st.Type['Tracer']] = getattr(self, '__class__', None)
    self.graph = Graph(tracer_cls=tracer_cls)

    self.tensor_attrs: Dict[st.Union[torch.Tensor, st.ScriptObject], str] = {}

    def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: st.List[str]):
        for k, v in m.__dict__.items():
            if isinstance(v, (torch.Tensor, st.ScriptObject)):
                self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
        for k, v in m.named_children():
            collect_tensor_attrs(v, prefix_atoms + [k])

    collect_tensor_attrs(self.root, [])

    assert isinstance(fn, st.FunctionType)

    fn_globals = fn.__globals__  # run before it gets patched
    fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args)

    parameter_proxy_cache: Dict[str, st.Proxy] = {}  # Reduce number of get_attr calls

    @st.functools.wraps(st._orig_module_getattr)
    def module_getattr_wrapper(mod, attr):
        attr_val = st._orig_module_getattr(mod, attr)
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    @st.functools.wraps(st._orig_module_call)
    def module_call_wrapper(mod, *args, **kwargs):
        def forward(*args, **kwargs):
            return st._orig_module_call(mod, *args, **kwargs)

        st._autowrap_check(patcher, getattr(getattr(mod, "forward", mod), "__globals__", {}),
                           self._autowrap_function_ids)
        return self.call_module(mod, forward, args, kwargs)

    with st._Patcher() as patcher:
        # allow duplicate patches to support the case of nested calls
        patcher.patch_method(torch.nn.Module, "__getattr__", module_getattr_wrapper, deduplicate=False)
        patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False)
        patcher.patch_method(Aggregation, "__call__", module_call_wrapper, deduplicate=False)
        st._patch_wrapped_functions(patcher)
        st._autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
        for module in self._autowrap_search:
            st._autowrap_check(patcher, module.__dict__, self._autowrap_function_ids)
        self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
                         type_expr=fn.__annotations__.get('return', None))

    self.submodule_paths = None

    return self.graph

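# Hedged sketch of how a Tracer subclass with the trace() override above is typically driven;
# `CustomTracer` is a hypothetical stand-in (the actual subclass name and the Aggregation type
# patched above are not shown in this snippet). The pattern is: instantiate the tracer, call
# trace() to get a Graph, then wrap tracer.root and the graph in a GraphModule.
import torch
from torch.fx import GraphModule, Tracer


class CustomTracer(Tracer):  # placeholder for the subclass that defines the trace() above
    pass


def _trace_example(model: torch.nn.Module) -> GraphModule:
    tracer = CustomTracer()
    graph = tracer.trace(model)  # returns a torch.fx.Graph
    return GraphModule(tracer.root, graph)
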
class Quantizer:
    def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so
        # initialize a quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        # Most of this function is just an interpreter for the graph.
        # It would be possible to put this behind some abstraction, but
        # it is pretty nice to be able to see exactly what is happening here
        # and hack on it.
        # Maybe we should just provide an example interpreter that people copy/paste
        # and then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
            env[node.name] = result

            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        self.quantized_graph = Graph()

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it over
                env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
            elif root_node is node:
                r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer chose not to quantize the node; take the entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)
        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if, for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants

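# Hedged usage sketch of the Quantizer above, assuming `traced` is a GraphModule (so that
# `mod.graph` exists) and that DEFAULT_QUANTIZATION_PATTERNS / DefaultQuant come from the same
# prototype module. observe() runs calibration data through the interpreter loop to gather
# statistics, and quantize() then rewrites the matched patterns into a new quantized GraphModule.
def _quantize_example(traced, calibration_batches):
    quantizer = Quantizer(traced)  # uses DEFAULT_QUANTIZATION_PATTERNS and DefaultQuant
    for batch in calibration_batches:
        quantizer.observe(batch)  # batch is the tuple of positional inputs to the model
    return quantizer.quantize()  # returns a quantized GraphModule
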
def call(self, graph_module: GraphModule) -> PassResult:
    """
    Return a new copy of torch.fx.GraphModule with CSE applied to the input graph.

    Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx

        def f(a):
            b = a * a
            c = a * a
            return b + c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
    """

    def get_aten_target(node):
        if hasattr(node.target, 'overloadpacket'):
            return node.target.overloadpacket
        return node.target

    modified = False
    new_graph = Graph()
    env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
    hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
    token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token

    for n in graph_module.graph.nodes:
        # Placeholder, output, and get_attr nodes are copied to the new graph without change;
        # also do not CSE away banned (e.g. random) operations.
        if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:
            # n.op == 'call_function'; we should never see n.op == 'call_module' or 'call_method' here.
            # Substitute args and kwargs members with their mapping in env if it exists;
            # specs can be used to reconstruct nested lists/dictionaries.
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, Node) and v in env:
                        arg_list[i] = env[v]
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # Each token corresponds to a unique node; nodes with the same token can be substituted.
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec,
            }

            # Hash the substituted args to a number; do not hash specs because specs are not hashable.
            hash_arg = hash((args, kwargs))
            hash_val = (n.target, hash_arg)

            # Check if the node has a substitute and can be eliminated.
            hash_val_in_hash_env = hash_val in hash_env
            if hash_val_in_hash_env and token_map[hash_val] == token:
                modified = True  # a substitution happens and the graph is modified
                env[n] = hash_env[hash_val]
                continue

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    csed_gm = GraphModule(graph_module, new_graph)
    return PassResult(csed_gm, modified)