Example #1
    def assert_functionalization(self, func, inpt, *, reapply_views=False):
        input_clone = inpt.clone()
        input_clone2 = inpt.clone()
        input_functional = torch._to_functional_tensor(input_clone2)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(inpt)

        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out_functional = func(input_functional)
        finally:
            torch._disable_functionalization()

        # We need to sync the input tensors first, in case there are any queued mutations left.
        torch._sync(input_functional)
        self.assertEqual(inpt, torch._from_functional_tensor(
            input_functional))  # input mutations should still occur

        # Handle tests with multi-tensor outputs
        if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
            out_refs, out_functionals = list(out_ref), list(out_functional)
        else:
            out_refs, out_functionals = [out_ref], [out_functional]

        for out_ref_, out_functional_ in zip(out_refs, out_functionals):
            self.assertEqual(out_ref_.size(), out_functional_.size())
            torch._sync(out_functional_)
            out_functional_unwrapped = torch._from_functional_tensor(
                out_functional_)
            self.assertEqual(out_ref_, out_functional_unwrapped)
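This harness runs func twice, once on a plain tensor and once on a functionalized clone, then checks that both the outputs and the input mutations agree. The torch._sync calls matter because functionalization queues mutations lazily; syncing flushes them so that torch._from_functional_tensor returns a tensor reflecting every pending update.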
Example #2
def wrapped(a):
    # `f` and `reapply_views` are captured from the enclosing scope.
    input_functional = torch._to_functional_tensor(a)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out = f(input_functional)
    finally:
        torch._disable_functionalization()
    # Flush any queued mutations before unwrapping.
    torch._sync(input_functional)
    tree_map(torch._sync, out)
    out_unwrapped = tree_map(torch._from_functional_tensor, out)
    return out_unwrapped
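Since f and reapply_views are free variables here, some enclosing setup is implied. A minimal sketch of what that setup might look like, assuming a wrapper factory (the name make_functional and the sample f are illustrative, not from the original):

import torch
from torch.utils._pytree import tree_map

def make_functional(f, *, reapply_views=False):
    # Illustrative factory: returns the `wrapped` closure shown above.
    def wrapped(a):
        input_functional = torch._to_functional_tensor(a)
        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out = f(input_functional)
        finally:
            torch._disable_functionalization()
        torch._sync(input_functional)
        tree_map(torch._sync, out)
        return tree_map(torch._from_functional_tensor, out)
    return wrapped

def f(x):
    # In-place add through a view; under functionalization this runs as
    # out-of-place ops while preserving the original semantics.
    y = x.view(-1)
    y.add_(1)
    return y

out = make_functional(f)(torch.zeros(4))  # tensor([1., 1., 1., 1.])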
Example #3
    def get_logs(self, func, inpt):
        input_clone_logging = LoggingTensor(inpt.clone())
        input_functional_logging = torch._to_functional_tensor(input_clone_logging)

        with capture_logs() as logs:
            log_input("input", input_clone_logging)
            torch._enable_functionalization()
            try:
                func(input_functional_logging)
            finally:
                torch._disable_functionalization()
        return logs
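LoggingTensor, capture_logs, and log_input come from PyTorch's internal test utilities (torch.testing._internal.logging_tensor). Each intercepted aten call is recorded as a string in logs, so the returned list shows which ops actually ran under the functionalization pass.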
Example #4
    def assert_functionalization(self, func, inpt):
        input_clone = inpt.clone()
        input_clone2 = inpt.clone()
        input_functional = torch._to_functional_tensor(input_clone2)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(inpt)

        torch._enable_functionalization()
        try:
            out_functional = func(input_functional)
        finally:
            torch._disable_functionalization()

        # We need to sync the input tensors first, in case there are any queued mutations left.
        torch._sync(input_functional)
        torch._sync(out_functional)
        self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
        self.assertEqual(inpt, torch._from_functional_tensor(input_functional))  # input mutations should still occur
Example #5
    def test_aliases_maintained_after_pass(self):
        def f(x):
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            z = x.view(4, 2)
            y.add_(tmp)
            return y, z

        input_functional = torch._to_functional_tensor(torch.ones(4, 2))
        torch._enable_functionalization()
        try:
            y, z = f(input_functional)
            torch._sync(y)
            torch._sync(z)
        finally:
            torch._disable_functionalization()

        # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
        _y = torch._from_functional_tensor(y)
        _z = torch._from_functional_tensor(z)
        self.assertTrue(are_aliased(_y, _z))
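are_aliased is a helper not shown in this snippet. A plausible minimal implementation (an assumption based on how PyTorch exposes view relationships, not necessarily the original) compares the tensors' _base attributes:

def are_aliased(x, y):
    # Assumed helper, not from the snippet: two tensors alias if one is
    # a view of the other, or if both are views of the same base tensor.
    if x._base is None and y._base is None:
        return False
    if x._base is not None and y._base is None:
        return x._base is y
    if x._base is None and y._base is not None:
        return y._base is x
    return x._base is y._base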