def test_module_repr(self):
    """Printing a scripted module shows the original (Python) class names
    of the module and all of its submodules, and ``original_name`` is kept."""

    class Submodule(nn.Module):
        def forward(self, x):
            return x

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(10, 10, 3)
            self.lin = nn.Linear(10, 10)
            self.sub = Submodule()

        def forward(self, x):
            return self.lin(x) + self.sub(x) + self.conv(x)

    m = torch.jit.script(MyModule())
    with self.capture_stdout() as out:
        print(m)

    # The captured repr must mention the module itself and every child.
    FileCheck().check('MyModule') \
               .check('Conv2d') \
               .check('Linear') \
               .check('Submodule') \
               .run(out[0])
    self.assertEqual(m.original_name, 'MyModule')
def validate_transformed_module(module_name, pattern_count_map, data_shape, prepack_removal=False): scripted_model = torch.jit.script(module_name()) scripted_model.eval() input_data = torch.rand(data_shape) ref_result = scripted_model(input_data) torch._C._jit_pass_insert_prepacked_ops(scripted_model._c) if (prepack_removal): scripted_model._c = torch._C._freeze_module(scripted_model._c) torch._C._jit_pass_fold_prepacking_ops(scripted_model._c) buffer = io.BytesIO() torch.jit.save(scripted_model, buffer) buffer.seek(0) deserialized_scripted_model = torch.jit.load(buffer) file_check = FileCheck() for pattern, v in pattern_count_map.items(): if (v == 0): file_check.check(pattern) elif (v == -1): file_check.check_not(pattern) else: file_check.check_count(pattern, v, exactly=True) file_check.run(deserialized_scripted_model.graph) xnnpack_result = deserialized_scripted_model(input_data) torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
def test_fake_dispatch_keys(self):
    """Fake tensors carry the expected dispatch keys, both normally and
    under inference mode (which drops the view/autograd keys)."""
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        t = torch.rand([4])
        FileCheck().check("CPU") \
                   .check("ADInplaceOrView") \
                   .check("AutogradCPU") \
                   .check("AutocastCPU") \
                   .run(torch._C._dispatch_key_set(t))

        with torch.inference_mode():
            u = torch.rand([4])
            v = u + u
            # Inference-mode results keep the CPU/Autocast keys ...
            FileCheck().check("CPU").check("AutocastCPU").run(
                torch._C._dispatch_key_set(v))
            # ... but must not carry view-tracking or autograd keys.
            FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(
                torch._C._dispatch_key_set(v))
def test_torchbind_return_instance(self):
    """Scripting a function that constructs a bound custom class must emit
    prim::CreateObject followed by a call to the real ``__init__``
    (rather than going through the ``__init__`` wrapper), and the returned
    instance must behave like the original class."""
    def foo():
        ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
        return ss

    scripted = torch.jit.script(foo)
    graph_str = str(scripted.graph)

    checker = FileCheck()
    checker.check('prim::CreateObject()')
    checker.check('prim::CallMethod[name="__init__"]')
    checker.run(graph_str)

    # The stack pops in LIFO order.
    out = scripted()
    self.assertEqual(out.pop(), "mom")
    self.assertEqual(out.pop(), "hi")
def test_error_stack_class(self): class X(object): def bad_fn(self): import pdb # noqa: F401 def fn(x) -> X: return X(10) try: torch.jit.script(fn) except Exception as e: checker = FileCheck() checker.check("import statements") checker.check("is being compiled since it was called from") checker.run(str(e))
def test_error_stack(self): def d(x: int) -> int: return x + 10 def c(x): return d("hello") + d(x) def b(x): return c(x) def a(x): return b(x) try: scripted = torch.jit.script(a) except RuntimeError as e: checker = FileCheck() checker.check("Expected a value of type 'int'") checker.check("def c(x)") checker.check("def b(x)") checker.check("def a(x)") checker.run(str(e))
def test_error_stack_module(self): def d(x): # type: (int) -> int return x + 10 def c(x): return d("hello") + d(x) def b(x): return c(x) class Submodule(torch.nn.Module): def __init__(self): super(Submodule, self).__init__() def forward(self, x): return b(x) class M(torch.nn.Module): def __init__(self): super(M, self).__init__() self.submodule = Submodule() def some_method(self, y): return y + self.submodule(y) def forward(self, x): return self.some_method(x) try: scripted = torch.jit.script(M()) except RuntimeError as e: checker = FileCheck() checker.check("Expected a value of type 'int'") checker.check("'c' is being compiled since it was called from 'b'") checker.check("'b' is being compiled since it was called from") checker.run(str(e))