def assert_functionalization(self, func, inpt, *, reapply_views=False):
    """Check that ``func`` behaves the same with and without functionalization.

    ``func`` is run once eagerly on ``inpt`` and once, with functionalization
    enabled, on a functional wrapper around a clone of ``inpt``. Both the
    outputs and any in-place mutation of the input must match.

    Args:
        func: callable taking one tensor; may return a single tensor or a
            tuple/list of tensors.
        inpt: input tensor. It is passed to ``func`` directly, so any in-place
            mutation ``func`` performs is applied to the caller's tensor.
        reapply_views: forwarded to ``torch._enable_functionalization``.
    """
    # Clone *before* the reference run so the functional run starts from the
    # original values rather than from whatever func(inpt) mutated.
    input_functional = torch._to_functional_tensor(inpt.clone())

    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out_functional = func(input_functional)
    finally:
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued mutations left.
    torch._sync(input_functional)
    # input mutations should still occur
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))

    # Handle tests with multi-tensor outputs (tuples or lists).
    multi_types = (tuple, list)
    if isinstance(out_ref, multi_types) and isinstance(out_functional, multi_types):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]
    # zip() below would silently drop trailing outputs if the counts differed.
    self.assertEqual(len(out_refs), len(out_functionals))

    for out_ref_, out_functional_ in zip(out_refs, out_functionals):
        self.assertEqual(out_ref_.size(), out_functional_.size())
        torch._sync(out_functional_)
        out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
        self.assertEqual(out_ref_, out_functional_unwrapped)
def test_multiple_levels_of_wrapping(self):
    """Check which of two stacked logging wrappers see a functionalized op.

    ``f`` performs a single in-place ``add_``. With an ``InplaceLoggingTensor``
    wrapped around a ``LoggingTensor``, the op is logged twice, once per
    wrapper. Per the expected logs below, a layer that was functionalized logs
    the out-of-place ``add`` and a layer that was not logs ``add_``.
    """
    def f(x):
        # call an inplace op and have it get logged twice (by the outer + inner wrapper)
        x.add_(1)

    # Test 1: both the inner and outer wrapper are "functionalized"
    x_inner_and_outer_functional = torch._to_functional_tensor(
        InplaceLoggingTensor(
            torch._to_functional_tensor(LoggingTensor(torch.ones(4)))))

    with capture_logs() as logs:
        f(x_inner_and_outer_functional)

    # Since both wrappers were functionalized, they both log "add"
    self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")

    # Test 2: only the inner wrapper is "functionalized"
    x_only_inner_functional = InplaceLoggingTensor(
        torch._to_functional_tensor(LoggingTensor(torch.ones(4))))

    with capture_logs() as logs:
        f(x_only_inner_functional)

    # Since only the inner wrapper is functionalized, only the inner (first) log is functionalized
    self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add_.Tensor($2, 1)""")

    # Test 3: only the outer wrapper is "functionalized"
    x_only_outer_functional = torch._to_functional_tensor(
        InplaceLoggingTensor(LoggingTensor(torch.ones(4))))

    with capture_logs() as logs:
        f(x_only_outer_functional)

    # Since only the outer wrapper is functionalized, only the outer (second) log is functionalized
    self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add_.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
def test_mixed_wrappers_invalid(self): x1_not_functional = torch.ones(4) x2_functional = torch._to_functional_tensor(torch.ones(4)) # When dealing with mixed functional + non functional tensors, # normal_tensor.add_(functional_tensor) is not valid # because normal_tensor would need to be "promoted" to a functional tensor. with self.assertRaises(RuntimeError): x1_not_functional.add_(x2_functional)
def wrapped(a):
    """Run the enclosing ``f`` on a functional view of ``a``; return plain outputs.

    ``f``, ``reapply_views`` and ``tree_map`` come from the enclosing scope,
    which is not shown here.
    """
    functional_a = torch._to_functional_tensor(a)

    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        result = f(functional_a)
    finally:
        torch._disable_functionalization()

    # Apply any queued mutations to the input and to every output leaf
    # before stripping the functional wrappers off the outputs.
    torch._sync(functional_a)
    tree_map(torch._sync, result)
    return tree_map(torch._from_functional_tensor, result)
def get_logs(self, func, inpt):
    """Return the log lines captured while ``func`` runs under functionalization.

    A clone of ``inpt`` is wrapped in a ``LoggingTensor`` and then in a
    functional tensor. ``func`` is called on that outer wrapper while
    functionalization is enabled.
    """
    logged_input = LoggingTensor(inpt.clone())
    functional_input = torch._to_functional_tensor(logged_input)

    with capture_logs() as captured:
        # Presumably tags the logging tensor as "input" in the captured
        # output; log_input is defined elsewhere in this file.
        log_input("input", logged_input)
        torch._enable_functionalization()
        try:
            func(functional_input)
        finally:
            torch._disable_functionalization()
    return captured
def test_mixed_wrappers_valid(self):
    """Mixing a functional and a non-functional LoggingTensor in ``+`` works.

    Only ``x2`` is functional. The logged ops must show the ``add`` kernel
    running on the mixed inputs, followed by an out-of-place ``add`` for the
    in-place ``z.add_(1)`` on the result.
    """
    def f(x, y):
        z = x + y
        z.add_(1)
        return z

    x1_not_functional = LoggingTensor(torch.ones(4))
    x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))

    with capture_logs() as logs:
        # Only the logged ops are checked; the returned tensor is not needed.
        f(x1_not_functional, x2_functional)

    # Make sure that functionalization ran the "+" kernel
    # with a functional + non-functional tensor, and wrapped the output appropriately.
    self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
def assert_functionalization(self, func, inpt, *, reapply_views=False):
    """Check that ``func`` behaves the same with and without functionalization.

    ``func`` is run once eagerly on ``inpt`` and once, with functionalization
    enabled, on a functional wrapper around a clone of ``inpt``. Both the
    outputs and any in-place mutation of the input must match.

    Args:
        func: callable taking one tensor; may return a single tensor or a
            tuple/list of tensors.
        inpt: input tensor. It is passed to ``func`` directly, so any in-place
            mutation ``func`` performs is applied to the caller's tensor.
        reapply_views: forwarded to ``torch._enable_functionalization``
            (keyword-only; defaults to False).
    """
    # Clone *before* the reference run so the functional run starts from the
    # original values rather than from whatever func(inpt) mutated.
    input_functional = torch._to_functional_tensor(inpt.clone())

    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out_functional = func(input_functional)
    finally:
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued mutations left.
    torch._sync(input_functional)
    # input mutations should still occur
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))

    # torch._sync / _from_functional_tensor take single tensors, so multi-tensor
    # outputs (tuples or lists) are compared element by element.
    multi_types = (tuple, list)
    if isinstance(out_ref, multi_types) and isinstance(out_functional, multi_types):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]
    # zip() below would silently drop trailing outputs if the counts differed.
    self.assertEqual(len(out_refs), len(out_functionals))

    for ref_out, functional_out in zip(out_refs, out_functionals):
        self.assertEqual(ref_out.size(), functional_out.size())
        torch._sync(functional_out)
        self.assertEqual(ref_out, torch._from_functional_tensor(functional_out))
def test_aliases_maintained_after_pass(self):
    """Two views of the same input remain aliased after functionalization.

    Inside ``f`` both outputs are views of ``x``. After syncing and unwrapping
    the functional outputs, they must still alias each other.
    """
    def f(x):
        ones = torch.ones(4, 2)
        first_view = x.view(4, 2)
        second_view = x.view(4, 2)
        first_view.add_(ones)
        return first_view, second_view

    functional_input = torch._to_functional_tensor(torch.ones(4, 2))

    torch._enable_functionalization()
    try:
        first, second = f(functional_input)
        # Apply any queued mutations to both outputs before unwrapping them.
        torch._sync(first)
        torch._sync(second)
    finally:
        torch._disable_functionalization()

    # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
    unwrapped_first = torch._from_functional_tensor(first)
    unwrapped_second = torch._from_functional_tensor(second)
    self.assertTrue(are_aliased(unwrapped_first, unwrapped_second))
def test_mixed_wrappers_valid(self):
    """Mixing a functional and a non-functional LoggingTensor in ``+`` works.

    Only ``x2`` is functional. The logged ops must show the ``add`` kernel
    running on the mixed inputs and its output being wrapped appropriately.
    """
    def f(x, y):
        z = x + y
        z.add_(1)
        return z

    x1_not_functional = LoggingTensor(torch.ones(4))
    x2_functional = torch._to_functional_tensor(
        LoggingTensor(torch.ones(4)))

    with capture_logs() as logs:
        # Only the logged ops are checked; the returned tensor is not needed.
        f(x1_not_functional, x2_functional)

    # NOTE(review): the alias entry presumably comes from x2 technically *not*
    # being a LoggingTensor (it *contains* one), whereas x1 *is* a LoggingTensor.
    # The important thing here though is that functionalization ran the "+" kernel
    # with a functional + non-functional tensor, and wrapped the output appropriately.
    self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.alias.default($2)
$4 = torch._ops.aten.add.Tensor($3, tensor(1))""")
def test_save_for_backwards_segfault(self):
    """Call ``exp()`` on a grad-requiring functional wrapper of a LoggingTensor.

    Judging by the test name, this guards against a crash when the op saves
    tensors for backward. There are no assertions: the test passes as long
    as the call completes.
    """
    logged = LoggingTensor(torch.randn(2, 2))
    functional_leaf = torch._to_functional_tensor(logged)
    functional_leaf = functional_leaf.requires_grad_(True)
    # Result intentionally discarded; only the call itself is under test.
    functional_leaf.exp()