Example #1
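This test constructs a torch.fx Graph that calls a leaf function from an imported package, wraps it in a GraphModule, round-trips it through torch.package, and then verifies that the packaged copy of the dependency is distinct from the one in the outer environment. The snippet assumes the surrounding test-suite context; a minimal sketch of the imports it relies on (package_a.subpackage is a local test fixture, not a published package) would be:

from io import BytesIO

import torch
from torch.fx import Graph, GraphModule
from torch.package import PackageExporter, PackageImporter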
    def test_package_fx_with_imports(self):
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c, ))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertTrue(packaged_dependency is not package_a.subpackage)
Example #2
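This test assembles a Graph node by node (placeholder, call_module, attribute fetch, call_method, call_function), attaches it to a plain torch.nn.Module that supplies the 'linear' submodule and 'bias' attribute, and checks the resulting GraphModule against the equivalent eager computation.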
    def test_graph_fns(self):
        # Manually assemble a graph: module call, attribute fetch,
        # method call, and free-function call.
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a, ))
        # Fetch the 'bias' attribute; older FX revisions spelled this get_param.
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d, ))
        g.output(e)

        # Owning module that provides the 'linear' submodule and 'bias' attribute.
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)

        # Compare against the equivalent eager computation.
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)
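As a follow-up, either GraphModule can be inspected after construction: GraphModule.code holds the Python source that FX generated for forward(), and Graph.print_tabular() (which requires the tabulate package) prints a node-by-node summary.

print(gm.code)            # generated Python for forward()
gm.graph.print_tabular()  # opcode / name / target / args per node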