def assert_functionalization(self, func, inpt, *, reapply_views=False):
        input_clone = inpt.clone()
        input_clone2 = inpt.clone()
        input_functional = torch._to_functional_tensor(input_clone2)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(inpt)

        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out_functional = func(input_functional)
        finally:
            torch._disable_functionalization()

        # We need to sync the input tensors first, in case there are any queued mutations left.
        torch._sync(input_functional)
        self.assertEqual(inpt, torch._from_functional_tensor(
            input_functional))  # input mutations should still occur

        # Handle tests with multi-tensor outputs
        if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
            out_refs, out_functionals = list(out_ref), list(out_functional)
        else:
            out_refs, out_functionals = [out_ref], [out_functional]

        for out_ref_, out_functional_ in zip(out_refs, out_functionals):
            self.assertEqual(out_ref_.size(), out_functional_.size())
            torch._sync(out_functional_)
            out_functional_unwrapped = torch._from_functional_tensor(
                out_functional_)
            self.assertEqual(out_ref_, out_functional_unwrapped)
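
    # Usage sketch (not from the original file): assert_functionalization above is assumed
    # to be a method on a unittest.TestCase subclass, so a test built on it might look like
    # the hypothetical method below; _view_then_mutate is illustrative only.
    def test_view_then_mutate_example(self):
        def _view_then_mutate(x):
            y = x.view(2, 2)
            y.add_(1)  # in-place mutation through a view of the input
            return y

        self.assert_functionalization(_view_then_mutate, torch.ones(4), reapply_views=True)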
    def test_multiple_levels_of_wrapping(self):
        def f(x):
            # call an inplace op and have it get logged twice (by the outer + inner wrapper)
            x.add_(1)

        # Test 1: both the inner and outer wrapper are "functionalized"
        x_inner_and_outer_functional = torch._to_functional_tensor(
            InplaceLoggingTensor(
                torch._to_functional_tensor(LoggingTensor(torch.ones(4)))))

        with capture_logs() as logs:
            f(x_inner_and_outer_functional)

        # Since both wrappers were functionalized, they both log the out-of-place "add"
        self.assertExpectedInline(
            '\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")

        # Test 2: only the inner wrapper is "functionalized"
        x_only_inner_functional = InplaceLoggingTensor(
            torch._to_functional_tensor(LoggingTensor(torch.ones(4))))

        with capture_logs() as logs:
            f(x_only_inner_functional)

        # Since only the inner wrapper is functionalized, only the inner (first) log shows the out-of-place "add"
        self.assertExpectedInline(
            '\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add_.Tensor($2, 1)""")

        # Test 3: only the outer wrapper is "functionalized"
        x_only_outer_functional = torch._to_functional_tensor(
            InplaceLoggingTensor(LoggingTensor(torch.ones(4))))

        with capture_logs() as logs:
            f(x_only_outer_functional)

        # Since only the outer wrapper is functionalized, only the outer (second) log shows the out-of-place "add"
        self.assertExpectedInline(
            '\n'.join(logs), """\
$1 = torch._ops.aten.add_.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
    def test_mixed_wrappers_invalid(self):
        x1_not_functional = torch.ones(4)
        x2_functional = torch._to_functional_tensor(torch.ones(4))

        # When dealing with mixed functional + non functional tensors,
        # normal_tensor.add_(functional_tensor) is not valid
        # because normal_tensor would need to be "promoted" to a functional tensor.
        with self.assertRaises(RuntimeError):
            x1_not_functional.add_(x2_functional)
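
        # Sketch (assumption, not from the original test): the mutation is expected to become
        # legal once the plain tensor is explicitly wrapped, since no implicit "promotion"
        # of a non-functional tensor is then required. Hypothetical, left commented out:
        #
        #     x1_functional = torch._to_functional_tensor(x1_not_functional)
        #     x1_functional.add_(x2_functional)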
def wrapped(a):
    input_functional = torch._to_functional_tensor(a)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out = f(input_functional)
    finally:
        torch._disable_functionalization()
    torch._sync(input_functional)
    tree_map(torch._sync, out)
    out_unwrapped = tree_map(torch._from_functional_tensor, out)
    return out_unwrapped
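
# Context sketch (assumption): `wrapped` above closes over `f` and `reapply_views`, so it is
# presumably returned by an enclosing helper along these lines. `_functionalize` is a
# hypothetical name, not a real torch API.
from torch.utils._pytree import tree_map

def _functionalize(f, *, reapply_views=False):
    def wrapped(a):
        input_functional = torch._to_functional_tensor(a)
        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out = f(input_functional)
        finally:
            torch._disable_functionalization()
        # Flush any pending mutations before unwrapping the input and the outputs.
        torch._sync(input_functional)
        tree_map(torch._sync, out)
        return tree_map(torch._from_functional_tensor, out)
    return wrapped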
Example #5
    def get_logs(self, func, inpt):
        input_clone_logging = LoggingTensor(inpt.clone())
        input_functional_logging = torch._to_functional_tensor(input_clone_logging)

        with capture_logs() as logs:
            log_input("input", input_clone_logging)
            torch._enable_functionalization()
            try:
                func(input_functional_logging)
            finally:
                torch._disable_functionalization()
        return logs
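
    # Usage sketch (hypothetical test, not from the original suite): run get_logs on a simple
    # in-place mutation. The exact aten ops are not reproduced here; a real test would pin
    # them down with self.assertExpectedInline('\n'.join(logs), """...""").
    def test_get_logs_example(self):
        logs = self.get_logs(lambda x: x.add_(1), torch.ones(4))
        self.assertTrue(len(logs) > 0)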
    def test_mixed_wrappers_valid(self):
        def f(x, y):
            z = x + y
            z.add_(1)
            return z

        x1_not_functional = LoggingTensor(torch.ones(4))
        x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))

        with capture_logs() as logs:
            y = f(x1_not_functional, x2_functional)

        # Make sure that functionalization ran the "+" kernel
        # with a functional + non-functional tensor, and wrapped the output appropriately.
        self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
Example #7
    def assert_functionalization(self, func, inpt):
        input_clone = inpt.clone()
        input_clone2 = inpt.clone()
        input_functional = torch._to_functional_tensor(input_clone2)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(inpt)

        torch._enable_functionalization()
        try:
            out_functional = func(input_functional)
        finally:
            torch._disable_functionalization()

        # We need to sync the input tensors first, in case there are any queued mutations left.
        torch._sync(input_functional)
        torch._sync(out_functional)
        self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
        self.assertEqual(inpt, torch._from_functional_tensor(input_functional))  # input mutations should still occur
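
    # Usage sketch (hypothetical, not from the original suite): unlike the tuple-aware helper
    # in an earlier example, this variant assumes `func` returns a single tensor.
    def test_single_output_example(self):
        self.assert_functionalization(lambda x: x.view(2, 2).add_(1), torch.ones(4))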
Example #8
    def test_aliases_maintained_after_pass(self):
        def f(x):
            tmp = torch.ones(4, 2)
            y = x.view(4, 2)
            z = x.view(4, 2)
            y.add_(tmp)
            return y, z

        input_functional = torch._to_functional_tensor(torch.ones(4, 2))
        torch._enable_functionalization()
        try:
            y, z = f(input_functional)
            torch._sync(y)
            torch._sync(z)
        finally:
            torch._disable_functionalization()

        # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
        _y = torch._from_functional_tensor(y)
        _z = torch._from_functional_tensor(z)
        self.assertTrue(are_aliased(_y, _z))
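
# Sketch (assumption): are_aliased is a helper defined elsewhere in the test utilities.
# One simple way to check that two plain tensors alias each other is to compare their
# underlying storage pointers; this is illustrative and may differ from the original helper.
def _are_aliased_sketch(x, y):
    return x.storage().data_ptr() == y.storage().data_ptr()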
Example #9
    def test_mixed_wrappers_valid(self):
        def f(x, y):
            z = x + y
            z.add_(1)
            return z

        x1_not_functional = LoggingTensor(torch.ones(4))
        x2_functional = torch._to_functional_tensor(
            LoggingTensor(torch.ones(4)))

        with capture_logs() as logs:
            y = f(x1_not_functional, x2_functional)

        # I think the alias trace is coming from the fact that x2 is technically *not*
        # a LoggingTensor (instead it *contains* a LoggingTensor), but x1 *is* a LoggingTensor.
        # The important thing here though is that functionalization ran the "+" kernel
        # with a functional + non-functional tensor, and wrapped the output appropriately.
        self.assertExpectedInline(
            '\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.alias.default($2)
$4 = torch._ops.aten.add.Tensor($3, tensor(1))""")
Example #10
    def test_save_for_backwards_segfault(self):
        inp = torch._to_functional_tensor(
            LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
        inp.exp()
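        # Note (inferred from the test name, not stated in the source): this is presumably a
        # regression test checking that calling an op on a grad-requiring functional tensor
        # wrapping a tensor subclass completes without segfaulting when autograd saves
        # tensors for backward, so no output assertion is needed.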