Ejemplo n.º 1
0
        def validate_transformed_module(module_name,
                                        pattern_count_map,
                                        data_shape,
                                        prepack_removal=False):
            scripted_model = torch.jit.script(module_name())
            scripted_model.eval()
            input_data = torch.rand(data_shape)
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
            if (prepack_removal):
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)
            file_check = FileCheck()
            for pattern, v in pattern_count_map.items():
                if (v == 0):
                    file_check.check(pattern)
                elif (v == -1):
                    file_check.check_not(pattern)
                else:
                    file_check.check_count(pattern, v, exactly=True)
            file_check.run(deserialized_scripted_model.graph)
            xnnpack_result = deserialized_scripted_model(input_data)
            torch.testing.assert_allclose(ref_result,
                                          xnnpack_result,
                                          rtol=1e-2,
                                          atol=1e-3)
Ejemplo n.º 2
0
 def test_lstm_traced_cpu(self):
     """Trace an LSTM cell on CPU and check the graph contains a FusionGroup.

     If kernel compilation fails with a 'Failed to compile' RuntimeError,
     the test warns and is skipped instead of failing. Any other
     RuntimeError is re-raised.
     """
     inputs = get_lstm_inputs('cpu')
     try:
         ge = self.checkTrace(LSTMCellF, inputs)
         graph = ge.graph_for(*inputs)
         # FileCheck must be instantiated. Calling ``check`` on the class
         # itself is an unbound call that raises TypeError, which the
         # ``except RuntimeError`` below would not catch.
         FileCheck().check("FusionGroup").run(str(graph))
     except RuntimeError as e:
         # ``str(e)`` is safe even when the exception carries no args,
         # unlike ``e.args[0]``.
         if 'Failed to compile' in str(e):
             warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                           'because the kernels sometimes trigger bugs in compilers '
                           '(most notably GCC 7.2).')
             raise unittest.SkipTest('Failed to compile') from e
         raise
Ejemplo n.º 3
0
    def test_error_stack_class(self):
        class X(object):
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        try:
            torch.jit.script(fn)
        except Exception as e:
            checker = FileCheck()
            checker.check("import statements")
            checker.check("is being compiled since it was called from")
            checker.run(str(e))
Ejemplo n.º 4
0
    def test_module_repr(self):
        """Printing a scripted module lists its own name and its submodules."""
        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        scripted = torch.jit.script(MyModule())

        with self.capture_stdout() as captured:
            print(scripted)

        # The names must appear in this order in the printed output.
        checker = FileCheck()
        for expected_name in ('MyModule', 'Conv2d', 'Linear', 'Submodule'):
            checker.check(expected_name)
        checker.run(captured[0])

        self.assertEqual(scripted.original_name, 'MyModule')
Ejemplo n.º 5
0
    def test_error_stack(self):
        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        try:
            scripted = torch.jit.script(a)
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("def c(x)")
            checker.check("def b(x)")
            checker.check("def a(x)")
            checker.run(str(e))
Ejemplo n.º 6
0
    def test_error_stack_module(self):
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def __init__(self):
                super(Submodule, self).__init__()

            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        try:
            scripted = torch.jit.script(M())
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("'c' is being compiled since it was called from 'b'")
            checker.check("'b' is being compiled since it was called from")
            checker.run(str(e))