def test_package_fx_with_imports(self):
    """Round-trip a hand-built GraphModule that calls a packaged leaf function.

    Builds an FX graph that calls ``package_a.subpackage.leaf_function`` on
    two placeholders and then applies ``torch.sin``. It pickles the resulting
    ``GraphModule`` into a torch.package archive with every module interned,
    then reloads it and checks two things:

    * the reloaded module produces the same output as the original, and
    * the packaged ``package_a.subpackage`` is a different module object
      from the one imported in the outer environment.
    """
    import package_a.subpackage

    # Manually construct a graph that invokes a leaf function.
    graph = Graph()
    a = graph.placeholder("x")
    b = graph.placeholder("y")
    c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
    d = graph.call_function(torch.sin, (c,))
    graph.output(d)
    gm = GraphModule(torch.nn.Module(), graph)

    # intern("**") matches every module, so dependencies such as package_a
    # are written into the archive instead of being resolved externally.
    f = BytesIO()
    with PackageExporter(f) as pe:
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", gm)
    f.seek(0)

    pi = PackageImporter(f)
    loaded_gm = pi.load_pickle("model", "model.pkl")

    input_x = torch.rand(2, 3)
    input_y = torch.rand(2, 3)
    loaded_out = loaded_gm(input_x, input_y)
    expected_out = gm(input_x, input_y)
    self.assertTrue(
        torch.allclose(loaded_out, expected_out),
        msg="packaged GraphModule output differs from the original",
    )

    # Check that the packaged version of the leaf_function dependency is
    # not the same as in the outer env.
    packaged_dependency = pi.import_module("package_a.subpackage")
    self.assertIsNot(packaged_dependency, package_a.subpackage)
def test_graph_fns(self):
    """A hand-built FX graph using several node kinds matches eager execution.

    The graph contains ``placeholder``, ``call_module``, ``get_param``,
    ``call_method`` and ``call_function`` nodes that together compute
    ``sin(linear(a) + bias)``. The result of running the ``GraphModule`` is
    compared against the same expression evaluated directly on the module.
    """
    g = Graph()
    a = g.placeholder('a')
    b = g.call_module('linear', (a,))
    c = g.get_param('bias')
    d = g.call_method('add', (b, c))
    e = g.call_function(torch.sin, (d,))
    g.output(e)

    # The graph refers to 'linear' and 'bias' by name, so both are attached
    # to the root module that is handed to GraphModule.
    mod = torch.nn.Module()
    mod.linear = torch.nn.Linear(3, 4)
    mod.bias = torch.rand(4)
    gm = GraphModule(mod, g)

    # Named ``inp`` rather than ``input`` to avoid shadowing the builtin.
    inp = torch.rand(3)
    r = gm(inp)
    ref = torch.sin(mod.linear(inp) + mod.bias)
    self.assertEqual(r, ref)
.. code-block:: python def forward(self, x, y): cat_1 = torch.cat([x, y]); x = y = None tanh_1 = torch.tanh(cat_1); cat_1 = None neg_1 = torch.neg(tanh_1); tanh_1 = None return neg_1 ''' # Create a graph independently of symbolic tracing graph = Graph() tracer = torch.fx.proxy.GraphAppendingTracer(graph) # Create raw Nodes raw1 = graph.placeholder('x') raw2 = graph.placeholder('y') # Initialize Proxies using the raw Nodes and graph's default tracer y = Proxy(raw1, tracer) z = Proxy(raw2, tracer) # y = Proxy(raw1) # z = Proxy(raw2) # Create other operations using the Proxies `y` and `z` a = torch.cat([y, z]) b = torch.tanh(a) c = torch.neg(b) # By using the graph's own appending tracer to create Proxies, # notice we can now use n-ary operators on operations without # multiple tracers being created at run-time (line 52) which leads
.. code-block:: python def forward(self, x, y): cat_1 = torch.cat([x, y]); x = y = None tanh_1 = torch.tanh(cat_1); cat_1 = None neg_1 = torch.neg(tanh_1); tanh_1 = None return neg_1 """ # Create a graph independently of symbolic tracing graph = Graph() # Create raw Nodes raw1 = graph.placeholder("x") raw2 = graph.placeholder("y") # Initialize Proxies using the raw Nodes y = Proxy(raw1) z = Proxy(raw2) # Create other operations using the Proxies `y` and `z` a = torch.cat([y, z]) b = torch.tanh(a) c = torch.neg(b) # Create a new output Node and add it to the Graph. By doing this, the # Graph will contain all the Nodes we just created (since they're all # linked to the output Node) graph.output(c.node)