def test_package_fx_with_imports(self):
    import package_a.subpackage

    # Manually construct a graph that invokes a leaf function
    graph = Graph()
    a = graph.placeholder("x")
    b = graph.placeholder("y")
    c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
    d = graph.call_function(torch.sin, (c,))
    graph.output(d)
    gm = GraphModule(torch.nn.Module(), graph)

    f = BytesIO()
    with PackageExporter(f) as pe:
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", gm)

    f.seek(0)
    pi = PackageImporter(f)
    loaded_gm = pi.load_pickle("model", "model.pkl")

    input_x = torch.rand(2, 3)
    input_y = torch.rand(2, 3)
    self.assertTrue(
        torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

    # Check that the packaged version of the leaf_function dependency is
    # not the same as in the outer env.
    packaged_dependency = pi.import_module("package_a.subpackage")
    self.assertTrue(packaged_dependency is not package_a.subpackage)
def deepcopy_graph(gm: GraphModule) -> GraphModule:
    """
    Performs a deepcopy of the GraphModule while also copying the relevant attributes to know
    whether the model was traced with dynamic axes, and what the values were if that is the case.
    """
    # First, create a copy of the module without the graph.
    graph = gm.__dict__.pop("_graph")
    fake_mod = torch.nn.Module()
    fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
    gm.__dict__["_graph"] = graph

    # Then, copy the graph.
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(graph, val_map=val_map)
    graph_clone.output(output_val)

    # Finally, create a new GraphModule (or a subclass of GraphModule) from the module and graph copies.
    # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
    clone = gm.__class__(fake_mod, graph_clone)

    # Restore the dynamic-axes-related attributes to the clone, remapping nodes of the
    # original graph to their copies via val_map.
    attributes = _cache_attributes(gm)
    attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
    attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
    _restore_attributes_(clone, attributes)

    return clone
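A minimal usage sketch for the helper above, assuming the `_cache_attributes` / `_restore_attributes_` helpers defined alongside `deepcopy_graph` accept a plainly traced module; `SmallModel` is a hypothetical example module introduced only for illustration:

import torch
from torch.fx import symbolic_trace

class SmallModel(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return torch.relu(x) + 1

traced = symbolic_trace(SmallModel())
clone = deepcopy_graph(traced)

# The clone owns its own Graph, so edits to it leave the original untouched.
assert clone.graph is not traced.graph
assert torch.equal(clone(torch.ones(3)), traced(torch.ones(3)))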
def test_remove_uses(self):
    g: torch.fx.Graph = Graph()
    x: torch.fx.Node = g.placeholder('x')
    relu: torch.fx.Node = g.call_function(torch.relu, (x,))
    neg: torch.fx.Node = g.call_function(torch.neg, (relu,))
    g.output(neg)

    neg.replace_all_uses_with(relu)
    g.erase_node(neg)

    self.assertTrue(neg not in relu.users)
def quantize(self):
    self.quantized_graph = Graph()
    self.delegate = DelegateBase(self.quantized_graph)

    env = {}
    quant_env = {}

    def load_arg(n, quantized):
        if not quantized:
            if n.name not in env and n.name in quant_env:
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]
        else:
            if n.name not in quant_env and n.name in env:
                quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
            return quant_env[n.name]

    def copy_recursive(node):
        def load_or_emit(n):
            if n.name in env or n.name in quant_env:
                return load_arg(n, quantized=False)
            else:
                return copy_recursive(n)

        r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
        return r

    for node in self.graph.nodes:
        root_node, obj = self.matches.get(node.name, (None, None))
        if root_node is None:
            # not quantized, just copy it
            env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
        elif root_node is node:
            r = obj.quantize(
                self, node,
                lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
            if r is NotImplemented:
                # the quantizer chose not to quantize the node; take the entire
                # match and just copy it over
                env[node.name] = copy_recursive(node)
            else:
                quant_env[node.name] = r

    self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
    return GraphModule(self.root, self.quantized_graph)
def test_graph_fns(self):
    g = Graph()
    a = g.placeholder('a')
    b = g.call_module('linear', (a,))
    c = g.get_param('bias')
    d = g.call_method('add', (b, c))
    e = g.call_function(torch.sin, (d,))
    g.output(e)

    mod = torch.nn.Module()
    mod.linear = torch.nn.Linear(3, 4)
    mod.bias = torch.rand(4)
    gm = GraphModule(mod, g)

    input = torch.rand(3)
    r = gm(input)
    ref = torch.sin(mod.linear(input) + mod.bias)
    self.assertEqual(r, ref)
By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.

.. code-block:: python

    def forward(self, x, y):
        cat_1 = torch.cat([x, y]);  x = y = None
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1
'''

# Create a graph independently of symbolic tracing
graph = Graph()
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and the graph's appending tracer
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)
# y = Proxy(raw1)
# z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
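The snippet above stops after the `tanh`. Following the `forward` listed in the docstring, a plausible completion (an assumption here, since the rest of the original example is not shown) negates the result, registers it as the graph output, and wraps the graph in a `GraphModule`:

c = torch.neg(b)

# Proxies record each operation as a Node on `graph`; c.node is the Node created
# by torch.neg, so registering it as the output completes the graph.
graph.output(c.node)

# Wrap the hand-built Graph in a GraphModule so it can be called like a regular module.
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)  # should roughly match the forward() shown in the docstring above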
def trace(self,
          root: st.Union[torch.nn.Module, st.Callable[..., Any]],
          concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
    if isinstance(root, torch.nn.Module):
        self.root = root
        fn = type(root).forward
        self.submodule_paths = {
            mod: name for name, mod in root.named_modules()
        }
    else:
        self.root = torch.nn.Module()
        fn = root

    tracer_cls: Optional[st.Type['Tracer']] = getattr(self, '__class__', None)
    self.graph = Graph(tracer_cls=tracer_cls)

    self.tensor_attrs: Dict[st.Union[torch.Tensor, st.ScriptObject], str] = {}

    def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: st.List[str]):
        for k, v in m.__dict__.items():
            if isinstance(v, (torch.Tensor, st.ScriptObject)):
                self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
        for k, v in m.named_children():
            collect_tensor_attrs(v, prefix_atoms + [k])

    collect_tensor_attrs(self.root, [])

    assert isinstance(fn, st.FunctionType)

    fn_globals = fn.__globals__  # run before it gets patched
    fn, args = self.create_args_for_root(
        fn, isinstance(root, torch.nn.Module), concrete_args)

    parameter_proxy_cache: Dict[str, st.Proxy] = {}  # Reduce number of get_attr calls

    @st.functools.wraps(st._orig_module_getattr)
    def module_getattr_wrapper(mod, attr):
        attr_val = st._orig_module_getattr(mod, attr)
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    @st.functools.wraps(st._orig_module_call)
    def module_call_wrapper(mod, *args, **kwargs):
        def forward(*args, **kwargs):
            return st._orig_module_call(mod, *args, **kwargs)

        st._autowrap_check(
            patcher,
            getattr(getattr(mod, "forward", mod), "__globals__", {}),
            self._autowrap_function_ids)
        return self.call_module(mod, forward, args, kwargs)

    with st._Patcher() as patcher:
        # allow duplicate patches to support the case of nested calls
        patcher.patch_method(torch.nn.Module, "__getattr__",
                             module_getattr_wrapper, deduplicate=False)
        patcher.patch_method(torch.nn.Module, "__call__",
                             module_call_wrapper, deduplicate=False)
        patcher.patch_method(Aggregation, "__call__",
                             module_call_wrapper, deduplicate=False)
        st._patch_wrapped_functions(patcher)
        st._autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
        for module in self._autowrap_search:
            st._autowrap_check(patcher, module.__dict__,
                               self._autowrap_function_ids)
        self.create_node(
            'output', 'output', (self.create_arg(fn(*args)),), {},
            type_expr=fn.__annotations__.get('return', None))

    self.submodule_paths = None
    return self.graph
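This method reads like a `trace` override on a custom `Tracer` subclass (the enclosing class is not shown here; `CustomTracer` and `MyModel` below are purely illustrative assumptions). A common way to consume such a tracer is to wrap the graph it returns in a `GraphModule`:

import torch
from torch.fx import GraphModule

# Hypothetical names: CustomTracer stands in for the class defining trace() above.
tracer = CustomTracer()
graph = tracer.trace(MyModel())

# tracer.root holds the module whose parameters and submodules the graph refers to.
gm = GraphModule(tracer.root, graph)
print(gm.graph)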
def call(self, graph_module: GraphModule) -> PassResult:
    """
    Return a new copy of torch.fx.GraphModule with CSE applied to the input graph.

    Example usage:

    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a):
        b = a * a
        c = a * a
        return b + c

    p = CSEPass()
    traced_graph = make_fx(f)(torch.tensor(1))
    print(traced_graph)
    result = p(traced_graph)
    print(result.graph_module)
    """
    def get_aten_target(node):
        if hasattr(node.target, 'overloadpacket'):
            return node.target.overloadpacket
        return node.target

    modified = False
    new_graph = Graph()
    env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
    hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
    token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token

    for n in graph_module.graph.nodes:
        # Placeholder, output, and get_attr nodes are copied to the new graph without change;
        # also do not CSE away random operations.
        if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:
            # n.op == 'call_function'; we should never see 'call_module' or 'call_method' here.
            # Substitute args and kwargs members with their mapping in env if one exists;
            # the specs can be used to reconstruct nested lists/dictionaries.
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, Node) and v in env:
                        arg_list[i] = env[v]
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # Each token corresponds to a unique node; nodes with the same token can be substituted.
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec
            }

            # Hash the substituted args to a number; do not hash the specs because specs are not hashable.
            hash_arg = hash((args, kwargs))
            hash_val = (n.target, hash_arg)

            # Check whether the node has a substitute and can be eliminated.
            hash_val_in_hash_env = hash_val in hash_env
            if hash_val_in_hash_env and token_map[hash_val] == token:
                modified = True  # a substitution happens and the graph is modified
                env[n] = hash_env[hash_val]
                continue

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    csed_gm = GraphModule(graph_module, new_graph)
    return PassResult(csed_gm, modified)