Пример #1
0
    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807);  slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807);  slice_tensor_2 = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones);  slice_tensor_3 = ones = None
    return zeros
    """)
Пример #2
0
    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal_default = torch.ops.aten.diagonal.default(zeros)
    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1);  diagonal_default = None
    return [zeros]
    """)
Пример #3
0
    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1);  slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1);  select_int_1 = None
    slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1);  slice_tensor_1 = None
    return clone_default
    """)
Пример #4
0
        def forward(ctx, *flat_tensor_args):
            nonlocal compiled_fw, compiled_bw, num_outs
            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
            if compiled_fw is None:
                flat_tensor_args = pytree.tree_map(
                    lambda x: x.detach().requires_grad_(x.requires_grad)
                    if isinstance(x, Tensor) else x, flat_tensor_args)
                fake_mode = FakeTensorMode.push(
                ) if config.use_fake_tensor else nullcontext()
                with preserve_rng_state(), fake_mode as mode:
                    # Set input tensors that require grad to leaves
                    fake_flat_tensor_args = pytree.tree_map(
                        lambda x: mode.from_tensor(x) if mode else x
                        if isinstance(x, Tensor) else x, flat_tensor_args)
                    with torch.set_grad_enabled(grad_state):
                        out = flat_fn(*fake_flat_tensor_args)
                    out = pytree.tree_map(
                        lambda x: x.detach().contiguous()
                        if isinstance(x, Tensor) else x, out)

                    if isinstance(out, (list, tuple)):
                        num_outs = len(out)
                    else:
                        num_outs = 1

                    joint_inputs = (fake_flat_tensor_args, out)
                    aot_decompositions = {
                        **aot_autograd_decompositions,
                        **decompositions
                    }
                    with torch.set_grad_enabled(grad_state):
                        fx_g = make_fx(joint_forward_backward,
                                       aot_decompositions)(*joint_inputs)

                        if config.use_functionalize:
                            # Functionalize the foward backward graph. First create a
                            # fake fn to make functionalize happy
                            def fake_fn(primals, tangents):
                                return fx_g(primals, tangents)

                            fx_g = make_fx(
                                functionalize(fake_fn))(*joint_inputs)
                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                # print(fw_module.code, bw_module.code)

                compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
                compiled_bw = bw_compiler(bw_module, bw_args)
            else:
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])
Пример #5
0
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(
            f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1);  a__1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
    select_int = torch.ops.aten.select.int(view_default_1, 0, 0);  view_default_1 = None
    view_default_2 = torch.ops.aten.view.default(select_int, [-1]);  select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
    view_default_3 = torch.ops.aten.view.default(clone_default, [-1]);  clone_default = None
    select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
    view_default_4 = torch.ops.aten.view.default(view_default_2, []);  view_default_2 = None
    view_default_5 = torch.ops.aten.view.default(view_default_3, [4]);  view_default_3 = None
    view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6);  view_default_6 = None
    return view_default_5
    """)
Пример #6
0
def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    """Trace, partition and compile ``flat_fn`` into an autograd function.

    The joint forward/backward graph is traced with ``make_fx`` (optionally
    functionalized when ``config.use_functionalize`` is set), split by
    ``aot_config.partition_fn`` and the forward half is compiled eagerly with
    ``aot_config.fw_compiler``.  The backward half is compiled with
    ``aot_config.bw_compiler`` the first time ``backward`` runs.

    Args:
        flat_fn: callable invoked as ``flat_fn(*flat_args)``.
        flat_args: the inputs used for tracing and for compiling the forward.
        aot_config: supplies ``decompositions``, ``partition_fn``,
            ``fw_compiler`` and ``bw_compiler``.

    Returns:
        The ``apply`` callable of a ``torch.autograd.Function`` subclass that
        wraps the compiled forward and (lazily compiled) backward.
    """
    joint_forward_backward = create_joint_forward_backward(flat_fn)

    # Run the function once and detach every tensor output for use as a
    # tracing input of the joint graph.
    out = pytree.tree_map(
        lambda leaf: leaf.detach().contiguous() if isinstance(leaf, Tensor) else leaf,
        flat_fn(*flat_args),
    )
    _num_outs = len(out) if isinstance(out, (list, tuple)) else 1

    joint_inputs = (flat_args, out)
    fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)

    if config.use_functionalize:
        # functionalize() is applied to a plain wrapper function around the
        # already-traced joint graph, which is then re-traced.
        def fake_fn(primals, tangents):
            return fx_g(primals, tangents)

        fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

    if config.debug_joint:
        print(fx_g.code)

    with torch.no_grad():
        with track_graph_compiling("joint"):
            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)

        if config.debug_graphs:
            print(fw_module.code, bw_module.code)

        with track_graph_compiling("forward"):
            compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args)

        if config.debug_partitioner:
            # Everything after the first _num_outs forward outputs is summed
            # here as "activations"; only tensor entries are counted.
            fw_outs = call_func_with_args(compiled_fw_func, flat_args)
            activation_sizes = sum(
                saved.storage().nbytes()
                for saved in fw_outs[_num_outs:]
                if isinstance(saved, torch.Tensor)
            )
            print(f"Real Activations Stored(GB): {activation_sizes/1e9}")

    class CompiledFunction(torch.autograd.Function):
        # Compiled callables and the output count are kept as class attributes;
        # compiled_bw stays None until the first backward call fills it in.
        compiled_fw = compiled_fw_func
        compiled_bw = None
        num_outs = _num_outs

        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
            split_at = CompiledFunction.num_outs
            fw_outs = call_func_with_args(
                CompiledFunction.compiled_fw, flat_tensor_args
            )
            # The outputs after the first `split_at` are saved for backward;
            # the first `split_at` are returned to the caller.
            ctx.save_for_backward(*fw_outs[split_at:])
            return tuple(fw_outs[:split_at])

        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            grads = [grad.contiguous() for grad in flat_args]
            # Backward inputs are the saved tensors followed by the incoming grads.
            all_args = [*ctx.saved_tensors, *grads]
            if CompiledFunction.compiled_bw is None:
                with track_graph_compiling("backward", True):
                    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                        bw_module, all_args
                    )
            ctx.maybe_clear_saved_tensors()
            return tuple(
                call_func_with_args(
                    CompiledFunction.compiled_bw, all_args, steal_args=True
                )
            )

    return CompiledFunction.apply