def validate_transformed_module(module_name, pattern_count_map, data_shape, prepack_removal=False): scripted_model = torch.jit.script(module_name()) scripted_model.eval() input_data = torch.rand(data_shape) ref_result = scripted_model(input_data) torch._C._jit_pass_insert_prepacked_ops(scripted_model._c) if (prepack_removal): scripted_model._c = torch._C._freeze_module(scripted_model._c) torch._C._jit_pass_fold_prepacking_ops(scripted_model._c) buffer = io.BytesIO() torch.jit.save(scripted_model, buffer) buffer.seek(0) deserialized_scripted_model = torch.jit.load(buffer) file_check = FileCheck() for pattern, v in pattern_count_map.items(): if (v == 0): file_check.check(pattern) elif (v == -1): file_check.check_not(pattern) else: file_check.check_count(pattern, v, exactly=True) file_check.run(deserialized_scripted_model.graph) xnnpack_result = deserialized_scripted_model(input_data) torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
def test_lstm_traced_cpu(self):
    """Trace LSTMCellF on CPU and verify the fuser produced a FusionGroup.

    A 'Failed to compile' RuntimeError is converted into a skip rather than
    a hard failure, because the CPU fuser kernels sometimes trigger bugs in
    host compilers.
    """
    inputs = get_lstm_inputs('cpu')
    try:
        ge = self.checkTrace(LSTMCellF, inputs)
        graph = ge.graph_for(*inputs)
        # Fix: FileCheck must be instantiated. Calling the unbound class
        # method FileCheck.check("FusionGroup") passes the string as `self`
        # and raises a TypeError instead of checking the graph.
        FileCheck().check("FusionGroup").run(str(graph))
    except RuntimeError as e:
        if 'Failed to compile' in e.args[0]:
            warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                          'because the kernels sometimes trigger bugs in compilers '
                          '(most notably GCC 7.2).')
            raise unittest.SkipTest('Failed to compile')
        else:
            raise
def test_error_stack_class(self): class X(object): def bad_fn(self): import pdb # noqa: F401 def fn(x) -> X: return X(10) try: torch.jit.script(fn) except Exception as e: checker = FileCheck() checker.check("import statements") checker.check("is being compiled since it was called from") checker.run(str(e))
def test_module_repr(self):
    """Printing a scripted module should list the module and every one of
    its submodules, and `original_name` should survive scripting."""
    class Submodule(nn.Module):
        def forward(self, x):
            return x

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv = nn.Conv2d(10, 10, 3)
            self.lin = nn.Linear(10, 10)
            self.sub = Submodule()

        def forward(self, x):
            return self.lin(x) + self.sub(x) + self.conv(x)

    scripted = torch.jit.script(MyModule())
    with self.capture_stdout() as captured:
        print(scripted)

    # All module names must appear, in order, in the printed repr.
    checker = FileCheck()
    for expected in ('MyModule', 'Conv2d', 'Linear', 'Submodule'):
        checker.check(expected)
    checker.run(captured[0])

    self.assertEqual(scripted.original_name, 'MyModule')
def test_error_stack(self): def d(x: int) -> int: return x + 10 def c(x): return d("hello") + d(x) def b(x): return c(x) def a(x): return b(x) try: scripted = torch.jit.script(a) except RuntimeError as e: checker = FileCheck() checker.check("Expected a value of type 'int'") checker.check("def c(x)") checker.check("def b(x)") checker.check("def a(x)") checker.run(str(e))
def test_error_stack_module(self): def d(x): # type: (int) -> int return x + 10 def c(x): return d("hello") + d(x) def b(x): return c(x) class Submodule(torch.nn.Module): def __init__(self): super(Submodule, self).__init__() def forward(self, x): return b(x) class M(torch.nn.Module): def __init__(self): super(M, self).__init__() self.submodule = Submodule() def some_method(self, y): return y + self.submodule(y) def forward(self, x): return self.some_method(x) try: scripted = torch.jit.script(M()) except RuntimeError as e: checker = FileCheck() checker.check("Expected a value of type 'int'") checker.check("'c' is being compiled since it was called from 'b'") checker.check("'b' is being compiled since it was called from") checker.run(str(e))