def test_meta_tensor_inplace_op(self):
    """Check that AOT tracing stores ``tensor_meta`` on every non-output node.

    The module below (a linear layer followed by gelu) results in in-place
    ops while tracing. The test checks that meta tensor information is still
    stored for those nodes in the graph handed to the forward compiler.
    """

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # nn.Parameter defaults to requires_grad=True, so the flag is not
            # needed on the wrapped tensors.
            self.weight = torch.nn.Parameter(torch.randn(3072, 768))
            self.bias = torch.nn.Parameter(torch.randn(3072))

        def forward(self, add_4):
            linear_4 = torch.nn.functional.linear(add_4, self.weight, bias=self.bias)
            gelu = torch.nn.functional.gelu(linear_4)
            return gelu

    # Every graph that reaches the compiler is recorded here. Without this,
    # the test would pass silently if the compiler were never invoked.
    checked_graphs = []

    def check_meta_tensor(fx_g, _):
        for node in fx_g.graph.nodes:
            # The output node is not expected to carry tensor metadata.
            if node.op != 'output':
                assert 'tensor_meta' in node.meta, (
                    f"node {node.name!r} (op={node.op!r}) is missing 'tensor_meta'"
                )
        checked_graphs.append(fx_g)
        return fx_g

    inp0 = torch.randn(16, 128, 768, requires_grad=True)
    inputs = [inp0]
    mod = MockModule().to(device="cpu")
    aot_mod = aot_module(mod, fw_compiler=check_meta_tensor)
    aot_mod(*inputs)
    self.assertTrue(checked_graphs, "fw_compiler was never invoked; nothing was checked")
def verify_aot_autograd(self, f, inp):
    """Assert that AOT-compiling ``f`` leaves its outputs and gradients unchanged.

    ``f`` is wrapped with ``aot_module`` when it is an ``nn.Module`` and with
    ``aot_function`` otherwise; ``nop`` is passed as the second argument to
    either wrapper. The original and the wrapped callable are both run on
    ``inp`` through ``_outs_and_grads``, and their outputs and gradients are
    compared with ``assertEqual``.
    """
    wrap = aot_module if isinstance(f, nn.Module) else aot_function
    compiled = wrap(f, nop)

    eager_outs, eager_grads = _outs_and_grads(f, inp)
    aot_outs, aot_grads = _outs_and_grads(compiled, inp)

    self.assertEqual(eager_outs, aot_outs)
    self.assertEqual(eager_grads, aot_grads)