    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # Naively, we shouldn't try to re-inplace the .ge() call,
            # because that would require changing the metadata of "b"
            # (from a float to a bool tensor).
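            # (b is torch.float32 here, while torch.ge returns torch.bool;
            # in-place ops can't change a tensor's dtype, so the pass must
            # leave ge out-of-place, as the expected graph below shows.)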
            c = torch.ge(b, a)
            return c

        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    add_tensor = torch.ops.aten.add.Tensor(clone_default, 1)
    ge_tensor = torch.ops.aten.ge.Tensor(add_tensor, clone_default);  add_tensor = clone_default = None
    return ge_tensor
    """)
    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807);  slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807);  slice_tensor_2 = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones);  slice_tensor_3 = ones = None
    return zeros
    """)
    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
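            # (If this add() were done in place, a_view -- an alias of a --
            # would observe the mutation, and c below would be computed from
            # the already-incremented values.)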
            b = a.add(1)
            # Second add() is fine to re-inplace
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, x_1):
    clone_default = torch.ops.aten.clone.default(x_1);  x_1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    add_tensor = torch.ops.aten.add.Tensor(clone_default, 1);  clone_default = None
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
    return view_default
    """)
    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4, ), (4, ), 0)
            # The first arg to select_scatter points to a different region of
            # memory than c's base.  This makes it invalid to re-inplace.
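            # Concretely: b = a[:, 1] starts at storage_offset 1, while
            # bad_mirror_of_b starts at storage_offset 0; the slot written by
            # this call (bad_mirror_of_b.select(0, 1)) lives at offset 4,
            # not at c's offset 5.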
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1);  select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0);  clone_default = None
    select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor);  select_int_2 = add_tensor = None
    return as_strided_default
    """)  # noqa: B950
    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
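            # (after the add() is re-inplaced, diagonal_scatter disappears and
            # the mutation lands directly in x's storage, so the graph below
            # returns [zeros] in place of [z])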
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal_default = torch.ops.aten.diagonal.default(zeros)
    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1);  diagonal_default = None
    return [zeros]
    """)
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4, ), (4, ), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
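            # Concretely: c = b[1] lives at storage_offset 5, but this call
            # writes into good_mirror_of_b.select(0, 0), which lives at
            # storage_offset 1, so the pass must not re-inplace it.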
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1);  select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1);  clone_default = None
    select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor);  select_int_2 = add_tensor = None
    return as_strided_default
    """)  # noqa: B950
    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
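            # (Functionalization turns this mutation into out-of-place ops plus
            # scatter calls; reinplace should collapse them back into a single
            # add_ on the selected element, as the expected graph below shows.)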
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1);  select_int_1 = None
    slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1);  slice_tensor_1 = None
    return clone_default
    """)
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is that the graph
        # contains a bunch of redundant views. Technically, half of
        # these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
    select_int = torch.ops.aten.select.int(view_default_1, 0, 0);  view_default_1 = None
    view_default_2 = torch.ops.aten.view.default(select_int, [-1]);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
    view_default_3 = torch.ops.aten.view.default(clone_default, [-1]);  clone_default = None
    select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
    view_default_4 = torch.ops.aten.view.default(view_default_2, []);  view_default_2 = None
    view_default_5 = torch.ops.aten.view.default(view_default_3, [4]);  view_default_3 = None
    view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6);  view_default_6 = None
    return view_default_5
    """)
    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_()
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, x_1):
    clone_default = torch.ops.aten.clone.default(x_1);  x_1 = None
    add_tensor = torch.ops.aten.add_.Tensor(clone_default, 1)
    return clone_default
    """)
    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't reinplace because b has overlapping memory.
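            # (expand() returns a stride-0 view here, so all 16 elements of b
            # alias the single element backing a; an in-place add would write
            # the same memory location repeatedly.)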
            c = b.add(1)
            return c

        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    expand_default = torch.ops.aten.expand.default(clone_default, [4, 4]);  clone_default = None
    add_tensor = torch.ops.aten.add.Tensor(expand_default, 1);  expand_default = None
    return add_tensor
    """)
    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4, ), (4, ), 1)
            # good_mirror_of_b points to the same region of memory as b,
            # and the scatter op below tries to scatter c_updated into the
            # same region that c currently takes up.
            # The reinplacing logic checks this by confirming that:
            #   c_updated
            #   good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
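            # Concretely, for the contiguous (4, 4) input below: b = a[:, 1]
            # has storage_offset 1 and stride (4,), so c = b[1] lives at
            # offset 5; good_mirror_of_b.select(0, 1) also lives at offset 5
            # with the same size () and stride (), so re-inplacing is safe.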
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1);  select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1);  clone_default = None
    return as_strided_default
    """)