def assert_functionalization(self, func, inpt, *, reapply_views=False):
    """Assert that ``func`` behaves the same with and without functionalization.

    ``func`` is run once on ``inpt`` directly and once, with functionalization
    enabled, on a functional-tensor wrapper around a clone of ``inpt``. The
    inputs (after any mutation by ``func``) and every output of the two runs
    are then compared with ``self.assertEqual``.

    Args:
        func: Callable taking a single tensor and returning either one
            tensor or a tuple of tensors.
        inpt: Input tensor. ``func`` is applied to it directly, so any
            in-place mutation ``func`` performs is visible to the caller.
        reapply_views: Forwarded unchanged to
            ``torch._enable_functionalization``.
    """
    # Wrap a private copy taken *before* the reference run, so mutations made
    # by ``func(inpt)`` cannot leak into the functionalized run.
    input_functional = torch._to_functional_tensor(inpt.clone())

    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out_functional = func(input_functional)
    finally:
        # Always switch functionalization back off, even if ``func`` raised.
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued
    # mutations left.
    torch._sync(input_functional)
    # Input mutations should still occur.
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))

    # Handle tests with multi-tensor outputs.
    if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]

    # zip() stops at the shorter sequence, so without this check a missing or
    # extra output in one of the runs would go unnoticed.
    self.assertEqual(len(out_refs), len(out_functionals))

    for out_ref_, out_functional_ in zip(out_refs, out_functionals):
        self.assertEqual(out_ref_.size(), out_functional_.size())
        torch._sync(out_functional_)
        out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
        self.assertEqual(out_ref_, out_functional_unwrapped)
def wrapped(a):
    """Call ``f`` on ``a`` with functionalization enabled and unwrap the result.

    ``f`` and ``reapply_views`` are not defined in this function; they are
    presumably closed over from an enclosing scope (TODO confirm against the
    surrounding definition).

    Args:
        a: Tensor to wrap into a functional tensor and pass to ``f``.

    Returns:
        The output of ``f`` with ``torch._from_functional_tensor`` mapped over
        it via ``tree_map``.
    """
    functional_arg = torch._to_functional_tensor(a)

    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        result = f(functional_arg)
    finally:
        # Functionalization is switched back off even when ``f`` raises.
        torch._disable_functionalization()

    # Sync the input and every output before unwrapping the outputs.
    torch._sync(functional_arg)
    tree_map(torch._sync, result)
    return tree_map(torch._from_functional_tensor, result)
def get_logs(self, func, inpt):
    """Run ``func`` under functionalization and return the captured logs.

    A clone of ``inpt`` is wrapped in ``LoggingTensor`` and then in a
    functional tensor. Inside ``capture_logs()`` the logging tensor is
    registered with ``log_input`` under the name ``"input"`` and ``func`` is
    called on the functional wrapper. The output of ``func`` is discarded.

    Args:
        func: Callable taking a single tensor.
        inpt: Tensor whose clone is fed to ``func``; ``inpt`` itself is not
            passed to ``func``.

    Returns:
        The object bound by ``with capture_logs() as ...``.
    """
    logged_input = LoggingTensor(inpt.clone())
    functional_input = torch._to_functional_tensor(logged_input)

    with capture_logs() as captured:
        log_input("input", logged_input)
        torch._enable_functionalization()
        try:
            func(functional_input)
        finally:
            # Functionalization is switched back off even when ``func`` raises.
            torch._disable_functionalization()
    return captured
def assert_functionalization(self, func, inpt):
    """Assert that ``func`` behaves the same with and without functionalization.

    ``func`` is run once on ``inpt`` directly and once, with functionalization
    enabled, on a functional-tensor wrapper around a clone of ``inpt``. The
    output and the input (after any mutation by ``func``) of the two runs are
    compared with ``self.assertEqual``.

    Args:
        func: Callable taking a single tensor and returning a single tensor
            (the output is passed directly to ``torch._sync``).
        inpt: Input tensor. ``func`` is applied to it directly, so any
            in-place mutation ``func`` performs is visible to the caller.
    """
    # Wrap a private copy taken *before* the reference run, so mutations made
    # by ``func(inpt)`` cannot leak into the functionalized run.
    input_functional = torch._to_functional_tensor(inpt.clone())

    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization()
    try:
        out_functional = func(input_functional)
    finally:
        # Always switch functionalization back off, even if ``func`` raised.
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued
    # mutations left.
    torch._sync(input_functional)
    torch._sync(out_functional)
    self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
    # Input mutations should still occur.
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))
def test_aliases_maintained_after_pass(self):
    """Two views of the same input must still alias after functionalization."""

    def make_views(x):
        # Build two identical views of ``x`` and mutate only the first one.
        ones = torch.ones(4, 2)
        first = x.view(4, 2)
        second = x.view(4, 2)
        first.add_(ones)
        return first, second

    wrapped_input = torch._to_functional_tensor(torch.ones(4, 2))
    torch._enable_functionalization()
    try:
        y, z = make_views(wrapped_input)
        # Both outputs are synced while functionalization is still enabled.
        for view in (y, z):
            torch._sync(view)
    finally:
        torch._disable_functionalization()

    # y and z are aliases inside of the function, and that aliasing
    # relationship should be maintained.
    unwrapped_y, unwrapped_z = (torch._from_functional_tensor(t) for t in (y, z))
    self.assertTrue(are_aliased(unwrapped_y, unwrapped_z))