def test_aot_module_simplified(self):
    """aot_module_simplified(mod, nop) must match eager execution.

    Runs the same module eagerly and through AOTAutograd (with the `nop`
    compiler, i.e. no graph transformation), then checks that both the
    forward output and the input gradients agree.
    """

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(20, 30)

        def forward(self, x, y):
            return (self.linear(x) + y,)

    mod = MockModule()
    mod.zero_grad()

    x = torch.randn(128, 20, requires_grad=True)
    y = torch.randn(128, 30, requires_grad=True)
    inputs = [x, y]
    # Independent leaf copies so the eager and AOT paths each accumulate
    # their own .grad (renamed loop var: don't shadow the tensor `x`).
    cloned_inputs = [inp.detach().clone().requires_grad_(True) for inp in inputs]

    # Eager reference forward + backward.
    ref = mod(*inputs)
    ref[0].sum().backward()

    # Same module through AOTAutograd.
    aot_mod = aot_module_simplified(mod, nop)
    aot_mod.zero_grad()
    res = aot_mod(*cloned_inputs)
    res[0].sum().backward()

    # unittest-style assertions (file convention; bare `assert` is stripped
    # under -O and reports no useful message on failure).
    self.assertTrue(torch.allclose(ref[0], res[0]))
    self.assertTrue(torch.allclose(inputs[0].grad, cloned_inputs[0].grad))
    self.assertTrue(torch.allclose(inputs[1].grad, cloned_inputs[1].grad))
def test_aot_module_simplified_preserves_stack_trace(self):
    """Stack traces recorded during fx tracing must survive AOT compilation.

    Traces a module with `record_stack_traces=True`, verifies every
    non-output node carries a stack trace, then compiles via
    aot_module_simplified with fw/bw compilers that re-check the traces
    on the compiled graphs. Runs backward so the bw_compiler check
    actually executes.
    """

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(20, 30)

        def forward(self, x, y):
            z = self.linear(x)
            z = z + y
            z = z.relu()
            return (z,)

    tracer = torch.fx.Tracer()
    tracer.record_stack_traces = True
    graph = tracer.trace(MockModule())
    mod = torch.fx.GraphModule(tracer.root, graph)

    # Sanity check: traces are present before compilation.
    for node in mod.graph.nodes:
        if node.op == 'output':
            continue
        self.assertIsNotNone(node.stack_trace)
        self.assertIn('test_pythonkey.py', node.stack_trace)

    def assert_compiler(gm: torch.fx.GraphModule, _):
        # Every computation node of the compiled graph keeps its trace.
        for node in gm.graph.nodes:
            if node.op in ('output', 'placeholder'):
                continue
            self.assertIsNotNone(node.stack_trace)
            self.assertIn('test_pythonkey.py', node.stack_trace)
        return gm.forward  # return a python callable

    aot_mod = aot_module_simplified(
        mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)

    x = torch.randn(128, 20, requires_grad=True)
    y = torch.randn(128, 30, requires_grad=True)
    inputs = [x, y]
    res = aot_mod(*inputs)
    # Fix: original never ran backward, so the backward graph (and its
    # bw_compiler assertions) could go unexercised. Trigger it explicitly.
    res[0].sum().backward()