    def assert_functionalization(self, func, inpt, *, reapply_views=False):
        input_clone = inpt.clone()
        input_clone2 = inpt.clone()
        input_functional = torch._to_functional_tensor(input_clone2)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(inpt)

        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out_functional = func(input_functional)
        finally:
            torch._disable_functionalization()

        # We need to sync the input tensors first, in case there are any queued mutations left.
        torch._sync(input_functional)
        self.assertEqual(inpt, torch._from_functional_tensor(
            input_functional))  # input mutations should still occur

        # Handle tests with multi-tensor outputs
        if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
            out_refs, out_functionals = list(out_ref), list(out_functional)
        else:
            out_refs, out_functionals = [out_ref], [out_functional]

        for out_ref_, out_functional_ in zip(out_refs, out_functionals):
            self.assertEqual(out_ref_.size(), out_functional_.size())
            torch._sync(out_functional_)
            out_functional_unwrapped = torch._from_functional_tensor(
                out_functional_)
            self.assertEqual(out_ref_, out_functional_unwrapped)
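A hedged usage sketch of the helper above (the test function `f`, the method name, and the sizes are illustrative; the helper is assumed to sit on a unittest.TestCase subclass, as the self.assertEqual calls suggest):

    def test_simple_view_inplace(self):
        def f(x):
            # Mutate the input through a view: functionalization should produce the
            # same output and the same input mutation as the eager reference run.
            y = x.view(4, 2)
            y.add_(1)
            return y

        self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)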
from torch.utils._pytree import tree_map  # used below to sync/unwrap nested outputs


def _functionalize(f, *, reapply_views):
    # Enclosing factory reconstructed from context: `wrapped` closes over `f`
    # and `reapply_views`, so it must be produced by a wrapper of this shape.
    def wrapped(a):
        input_functional = torch._to_functional_tensor(a)
        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out = f(input_functional)
        finally:
            torch._disable_functionalization()
        # Flush queued mutations on the input, then sync and unwrap every output.
        torch._sync(input_functional)
        tree_map(torch._sync, out)
        out_unwrapped = tree_map(torch._from_functional_tensor, out)
        return out_unwrapped

    return wrapped
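For illustration, the reconstructed factory could be used like this (the function `f` and the input shape are made up for the example):

def f(x):
    y = x.view(2, 2)
    y.add_(1)
    return y

g = _functionalize(f, reapply_views=True)
out = g(torch.zeros(2, 2))  # f runs under functionalization; `out` is a plain tensor again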
    def assert_functionalization(self, func, inpt):
        input_clone = inpt.clone()
        input_clone2 = inpt.clone()
        input_functional = torch._to_functional_tensor(input_clone2)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(inpt)

        torch._enable_functionalization()
        try:
            out_functional = func(input_functional)
        finally:
            torch._disable_functionalization()

        # We need to sync the input tensors first, in case there are any queued mutations left.
        torch._sync(input_functional)
        torch._sync(out_functional)
        self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
        self.assertEqual(inpt, torch._from_functional_tensor(input_functional))  # input mutations should still occur
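The two sync calls matter because a FunctionalTensorWrapper applies mutations lazily; a minimal sketch of that behaviour using only the private APIs shown above (exact semantics may differ across PyTorch versions):

wrapped = torch._to_functional_tensor(torch.zeros(2))
torch._enable_functionalization()
try:
    wrapped.add_(1)  # recorded as a functional update rather than applied immediately
finally:
    torch._disable_functionalization()
torch._sync(wrapped)  # flush the queued mutation into the wrapper's value
updated = torch._from_functional_tensor(wrapped)  # now reflects the add_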
Example #4
    def test_aliases_maintained_after_pass(self):
        def f(x):
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            z = x.view(4, 2)
            y.add_(tmp)
            return y, z

        input_functional = torch._to_functional_tensor(torch.ones(4, 2))
        torch._enable_functionalization()
        try:
            y, z = f(input_functional)
            torch._sync(y)
            torch._sync(z)
        finally:
            torch._disable_functionalization()

        # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
        _y = torch._from_functional_tensor(y)
        _z = torch._from_functional_tensor(z)
        self.assertTrue(are_aliased(_y, _z))
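The test relies on an are_aliased helper that this excerpt does not define; a minimal sketch under the assumption that "aliased" simply means "backed by the same storage":

def are_aliased(x, y):
    # Assumption: two tensors count as aliased when their views share one storage.
    return x.storage().data_ptr() == y.storage().data_ptr()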
Example #5
def _functorch_str(tensor, *, tensor_contents=None):
    level = _C.maybe_get_level(tensor)
    if level == -1:
        return _old_str(tensor)

    if _C.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = _C.get_unwrapped(tensor)
    dl_enabled = _C.tls_set_is_included()
    try:
        # Temporarily remove kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey from the included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(False)
        value_repr = repr(value)
    finally:
        # Reenable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(True)

    if _C.is_batchedtensor(tensor):
        bdim = _C.maybe_get_bdim(tensor)
        assert bdim != -1
        return (f'BatchedTensor(lvl={level}, bdim={bdim}, value=\n'
                f'{prep_value(value_repr)}\n'
                f')')
    if _C.is_gradtrackingtensor(tensor):
        return (f'GradTrackingTensor(lvl={level}, value=\n'
                f'{prep_value(value_repr)}\n'
                f')')
    if _C.is_functionaltensor(tensor):
        return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})'

    raise ValueError(
        "We don't know how to print this, please file us an issue")