def test_package_fx_with_imports(self):
    """Round-trip a hand-built FX GraphModule that calls a leaf function.

    The graph computes ``sin(package_a.subpackage.leaf_function(x, y))``.
    The module is saved with ``PackageExporter`` (interning ``"**"``),
    reloaded with ``PackageImporter``, and then checked in two ways:

    * the loaded module gives the same output as the original on random
      inputs, and
    * the importer's ``package_a.subpackage`` is a different module object
      from the one imported in the outer environment.
    """
    import package_a.subpackage

    # Manually construct a graph that invokes a leaf function.
    graph = Graph()
    a = graph.placeholder("x")
    b = graph.placeholder("y")
    c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
    d = graph.call_function(torch.sin, (c,))
    graph.output(d)

    gm = GraphModule(torch.nn.Module(), graph)

    f = BytesIO()
    with PackageExporter(f) as pe:
        # Intern every module matching "**", so the leaf function's module
        # is presumably packaged with the model rather than referenced
        # externally.
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", gm)

    # Rewind the in-memory buffer so the importer reads from the start.
    f.seek(0)

    pi = PackageImporter(f)
    loaded_gm = pi.load_pickle("model", "model.pkl")

    input_x = torch.rand(2, 3)
    input_y = torch.rand(2, 3)
    self.assertTrue(
        torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y))
    )

    # Check that the packaged version of the leaf_function dependency is
    # not the same as in the outer env.
    packaged_dependency = pi.import_module("package_a.subpackage")
    self.assertIsNot(packaged_dependency, package_a.subpackage)
def test_graph_fns(self):
    """Build a graph from each node-creation helper and compare to eager.

    The graph computes ``sin(linear(a) + bias)`` using ``call_module``,
    ``get_param``, ``call_method`` and ``call_function``. It is wrapped in a
    ``GraphModule`` over a plain ``nn.Module`` that holds ``linear`` and
    ``bias``. The result must equal the same expression evaluated eagerly.
    """
    g = Graph()
    a = g.placeholder('a')
    b = g.call_module('linear', (a,))
    c = g.get_param('bias')
    d = g.call_method('add', (b, c))
    e = g.call_function(torch.sin, (d,))
    g.output(e)

    # Root module supplying the submodule and attribute the graph refers to
    # by name ('linear' and 'bias').
    mod = torch.nn.Module()
    mod.linear = torch.nn.Linear(3, 4)
    mod.bias = torch.rand(4)
    gm = GraphModule(mod, g)

    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.rand(3)
    r = gm(x)
    ref = torch.sin(mod.linear(x) + mod.bias)
    self.assertEqual(r, ref)