def test_aot_autograd_exhaustive(self, device, dtype, op):
    def f(args, kwargs):
        return op.op(*args, **kwargs)

    if not op.supports_autograd:
        return

    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
    for sample_input in sample_inputs_itr:
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        if not all(isinstance(i, torch.Tensor) and i.dtype == torch.float for i in args):
            self.skipTest("not all inputs are float tensors")
        if not all(isinstance(i, torch.Tensor) and i.dtype == torch.float for i in kwargs.values()):
            self.skipTest("not all inputs are float tensors")
        t = f(args, kwargs)
        if isinstance(t, tuple):
            self.skipTest("output is a tuple")

        def reset_grads():
            def clear_grad(x):
                x.grad = None

            pytree.tree_map(clear_grad, args)

        def get_grads(args):
            return pytree.tree_map(lambda x: x.grad, args)

        # Compare gradients from the eager op against the AOT-compiled version.
        compiled_f = compiled_function(f, nop, nop)

        reset_grads()
        compiled_f(args, kwargs).sum().backward()
        compiled_grad = get_grads(args)

        reset_grads()
        f(args, kwargs).sum().backward()
        orig_grad = get_grads(args)
        self.assertEqual(orig_grad, compiled_grad)

        # Re-run with fresh random inputs to check that the compiled function
        # was not specialized to the original tensor values.
        def create_new_arg(x):
            return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)

        args = pytree.tree_map(create_new_arg, args)

        reset_grads()
        compiled_f(args, kwargs).sum().backward()
        compiled_grad = get_grads(args)

        reset_grads()
        f(args, kwargs).sum().backward()
        orig_grad = get_grads(args)
        self.assertEqual(orig_grad, compiled_grad)
def test_recompute_partitioning(self):
    def fn(a, b):
        return torch.sin(torch.sin(a)) + b

    # Reference calculation
    ref_a = torch.rand(10, 10, requires_grad=True)
    ref_b = torch.rand(10, 10, requires_grad=True)
    ref = fn(ref_a, ref_b)
    ref.sum().backward()

    # Compiled function calculation
    res_a = ref_a.clone().detach().requires_grad_(True)
    res_b = ref_b.clone().detach().requires_grad_(True)

    def compile_fn(x, _):
        return x

    compiled_fn = compiled_function(fn, compile_fn, compile_fn, min_cut_rematerialization_partition)
    res = compiled_fn(res_a, res_b)
    res.sum().backward()
    assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
    assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3)
    assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3)