def assert_functionalization(self, func, inpt, *, reapply_views=False):
    """Check that ``func`` gives the same results with and without functionalization.

    ``func`` is called eagerly on ``inpt`` and, separately, under functionalization
    on a functional-tensor wrapper around a clone of ``inpt``. The outputs (a single
    tensor or a tuple of tensors) and the final value of the input must agree.

    Args:
        func: callable under test; it may mutate its argument and may return a
            single tensor or a tuple of tensors.
        inpt: input tensor. NOTE: ``func`` is applied to ``inpt`` itself, so any
            in-place mutation it performs is visible to the caller.
        reapply_views: forwarded to ``torch._enable_functionalization``.
    """
    # Clone before running ``func`` eagerly so the functional copy starts from the
    # un-mutated input value.
    input_functional = torch._to_functional_tensor(inpt.clone())

    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out_functional = func(input_functional)
    finally:
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued mutations left.
    torch._sync(input_functional)
    # input mutations should still occur
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))

    # Handle tests with multi-tensor outputs
    if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]
    # zip() below would silently ignore extra outputs if the counts differed.
    self.assertEqual(len(out_refs), len(out_functionals))

    for out_ref_, out_functional_ in zip(out_refs, out_functionals):
        self.assertEqual(out_ref_.size(), out_functional_.size())
        torch._sync(out_functional_)
        out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
        self.assertEqual(out_ref_, out_functional_unwrapped)
def wrapped(a):
    """Run ``f`` on ``a`` under functionalization and return plain (unwrapped) outputs.

    ``f``, ``reapply_views`` and ``tree_map`` are taken from the enclosing scope.
    The input is wrapped as a functional tensor and ``f`` runs with
    functionalization enabled. Any queued updates on the input and on every output
    are then flushed, and each output is unwrapped back to an ordinary tensor.
    """
    wrapped_input = torch._to_functional_tensor(a)

    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        result = f(wrapped_input)
    finally:
        # Always turn functionalization back off, even if ``f`` raised.
        torch._disable_functionalization()

    # Flush pending mutations: the input first, then every tensor in the output tree.
    torch._sync(wrapped_input)
    tree_map(torch._sync, result)
    return tree_map(torch._from_functional_tensor, result)
def assert_functionalization(self, func, inpt):
    """Check that ``func`` gives the same results with and without functionalization.

    ``func`` is called eagerly on ``inpt`` and, separately, under functionalization
    on a functional-tensor wrapper around a clone of ``inpt``. The outputs (a single
    tensor or a tuple of tensors) and the final value of the input must agree.

    NOTE: ``func`` is applied to ``inpt`` itself, so any in-place mutation it
    performs is visible to the caller.
    """
    # Clone before running ``func`` eagerly so the functional copy starts from the
    # un-mutated input value.
    input_functional = torch._to_functional_tensor(inpt.clone())

    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization()
    try:
        out_functional = func(input_functional)
    finally:
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued mutations left.
    torch._sync(input_functional)

    # Sync/unwrap are applied per tensor, so multi-output functions are compared
    # element by element instead of passing the whole tuple through.
    if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]
    # zip() below would silently ignore extra outputs if the counts differed.
    self.assertEqual(len(out_refs), len(out_functionals))
    for out_ref_, out_functional_ in zip(out_refs, out_functionals):
        torch._sync(out_functional_)
        self.assertEqual(out_ref_, torch._from_functional_tensor(out_functional_))

    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))  # input mutations should still occur
def test_aliases_maintained_after_pass(self):
    """Two views of the same input must still alias each other after functionalization."""
    def mutate_through_view(base):
        increment = torch.ones(4, 2)
        first_view = base.view(4, 2)
        second_view = base.view(4, 2)
        # Mutate through one view only; the other view should observe it.
        first_view.add_(increment)
        return first_view, second_view

    wrapped_input = torch._to_functional_tensor(torch.ones(4, 2))
    torch._enable_functionalization()
    try:
        y, z = mutate_through_view(wrapped_input)
        for view in (y, z):
            torch._sync(view)
    finally:
        torch._disable_functionalization()

    # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
    unwrapped_y, unwrapped_z = (torch._from_functional_tensor(t) for t in (y, z))
    self.assertTrue(are_aliased(unwrapped_y, unwrapped_z))
def _functorch_str(tensor, *, tensor_contents=None):
    """Render a functorch-wrapped tensor as a readable string.

    A tensor for which ``_C.maybe_get_level`` returns -1 is handed to ``_old_str``.
    Otherwise the wrapped inner tensor is repr'd and embedded in a
    ``BatchedTensor`` / ``GradTrackingTensor`` / ``FunctionalTensor`` description
    tagged with the wrapper's level. ``tensor_contents`` is accepted but not used
    by this function.

    Raises:
        ValueError: if the wrapper is none of the three recognised kinds.
    """
    level = _C.maybe_get_level(tensor)
    if level == -1:
        return _old_str(tensor)

    if _C.is_functionaltensor(tensor):
        # The FunctionalTensorWrapper is about to be unwrapped, so bring it up to
        # date first.
        torch._sync(tensor)

    unwrapped = _C.get_unwrapped(tensor)
    restore_dynamic_layer_keys = _C.tls_set_is_included()
    try:
        # Temporarily drop kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey from
        # the included dispatch keys while the inner tensor is being repr'd.
        if restore_dynamic_layer_keys:
            _C._set_dynamic_layer_keys_included(False)
        value_repr = repr(unwrapped)
    finally:
        # Re-include the dynamic-layer keys only if they were removed above.
        if restore_dynamic_layer_keys:
            _C._set_dynamic_layer_keys_included(True)

    if _C.is_batchedtensor(tensor):
        bdim = _C.maybe_get_bdim(tensor)
        assert bdim != -1
        prepared = prep_value(value_repr)
        return (f'BatchedTensor(lvl={level}, bdim={bdim}, value=\n'
                f'{prepared}\n'
                f')')

    if _C.is_gradtrackingtensor(tensor):
        prepared = prep_value(value_repr)
        return (f'GradTrackingTensor(lvl={level}, value=\n'
                f'{prepared}\n'
                f')')

    if _C.is_functionaltensor(tensor):
        return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})'

    raise ValueError(
        "We don't know how to print this, please file us an issue")