Ejemplo n.º 1
0
    def test_module_repr(self):
        """Check that printing a scripted module names it and each child module.

        The printed text must contain, in this order, ``MyModule``, ``Conv2d``,
        ``Linear`` and ``Submodule``; the scripted module must also report
        ``MyModule`` as its ``original_name``.
        """
        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                # Children are assigned in the same order the printed output
                # is checked below: conv, lin, sub.
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        scripted = torch.jit.script(MyModule())

        with self.capture_stdout() as captured:
            print(scripted)

        # FileCheck directives are matched in the order they are registered.
        checker = FileCheck()
        for expected_name in ('MyModule', 'Conv2d', 'Linear', 'Submodule'):
            checker.check(expected_name)
        checker.run(captured[0])

        self.assertEqual(scripted.original_name, 'MyModule')
Ejemplo n.º 2
0
        def validate_transformed_module(module_name,
                                        pattern_count_map,
                                        data_shape,
                                        prepack_removal=False):
            """Script a module, rewrite it with prepacked ops, and validate it.

            The module is scripted, run once for a reference output, rewritten
            by the prepacked-op insertion pass, saved to and reloaded from an
            in-memory buffer, and then checked both structurally (graph
            patterns) and numerically (output close to the reference).

            Args:
                module_name: Callable invoked with no arguments to build the
                    module to script.
                pattern_count_map: Maps a graph pattern to an expected count.
                    ``0`` means the pattern must be present, ``-1`` means it
                    must be absent, and any other value means it must occur
                    exactly that many times.
                data_shape: Shape passed to ``torch.rand`` to build the input.
                prepack_removal: When true, the module is additionally frozen
                    and the prepacking-op folding pass is run on it.
            """
            model = torch.jit.script(module_name())
            model.eval()
            sample = torch.rand(data_shape)
            expected = model(sample)

            # Rewrite the scripted module in place before serializing it.
            torch._C._jit_pass_insert_prepacked_ops(model._c)
            if prepack_removal:
                model._c = torch._C._freeze_module(model._c)
                torch._C._jit_pass_fold_prepacking_ops(model._c)

            # Round-trip through an in-memory buffer; all checks below run on
            # the reloaded module rather than on the one rewritten above.
            stream = io.BytesIO()
            torch.jit.save(model, stream)
            stream.seek(0)
            reloaded = torch.jit.load(stream)

            checker = FileCheck()
            for pattern, count in pattern_count_map.items():
                if count == -1:
                    checker.check_not(pattern)
                elif count == 0:
                    checker.check(pattern)
                else:
                    checker.check_count(pattern, count, exactly=True)
            checker.run(reloaded.graph)

            actual = reloaded(sample)
            torch.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-3)
Ejemplo n.º 3
0
    def test_fake_dispatch_keys(self):
        """Check the dispatch key set of tensors created under FakeTensorMode.

        Outside inference mode the key set must list ``CPU``,
        ``ADInplaceOrView``, ``AutogradCPU`` and ``AutocastCPU`` in that order.
        Inside inference mode it must still list ``CPU`` and ``AutocastCPU``
        but neither ``ADInplaceOrView`` nor any ``Autograd`` key.
        """
        with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
            x = torch.rand([4])
            expected = FileCheck()
            for key in ("CPU", "ADInplaceOrView", "AutogradCPU", "AutocastCPU"):
                expected.check(key)
            expected.run(torch._C._dispatch_key_set(x))

            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x

                present = FileCheck().check("CPU").check("AutocastCPU")
                present.run(torch._C._dispatch_key_set(y))

                absent = FileCheck().check_not("ADInplaceOrView").check_not("Autograd")
                absent.run(torch._C._dispatch_key_set(y))
Ejemplo n.º 4
0
    def test_torchbind_return_instance(self):
        """Check that a scripted function can build and return a custom-class object.

        The scripted graph must create the object and then call its
        ``__init__`` method, and the returned object must pop the strings it
        was built with in last-in, first-out order.
        """
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted_foo = torch.jit.script(foo)

        # The graph should show explicit object creation followed by a call
        # to __init__, not some wrapper around __init__.
        checker = FileCheck()
        checker.check('prim::CreateObject()')
        checker.check('prim::CallMethod[name="__init__"]')
        checker.run(str(scripted_foo.graph))

        stack = scripted_foo()
        self.assertEqual(stack.pop(), "mom")
        self.assertEqual(stack.pop(), "hi")
Ejemplo n.º 5
0
    def test_error_stack_class(self):
        class X(object):
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        try:
            torch.jit.script(fn)
        except Exception as e:
            checker = FileCheck()
            checker.check("import statements")
            checker.check("is being compiled since it was called from")
            checker.run(str(e))
Ejemplo n.º 6
0
    def test_error_stack(self):
        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        try:
            scripted = torch.jit.script(a)
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("def c(x)")
            checker.check("def b(x)")
            checker.check("def a(x)")
            checker.run(str(e))
Ejemplo n.º 7
0
    def test_error_stack_module(self):
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def __init__(self):
                super(Submodule, self).__init__()

            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        try:
            scripted = torch.jit.script(M())
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("'c' is being compiled since it was called from 'b'")
            checker.check("'b' is being compiled since it was called from")
            checker.run(str(e))