def test_reinplace_index_mutation(self):
    def f():
        a = torch.zeros(4, 4, 4)
        a[:, 2:] = torch.ones(4, 2, 4)
        return a

    if not HAS_FUNCTIONALIZATION:
        return
    f2 = reinplace(make_fx(functionalize(f))())
    expected_out = f()
    actual_out = f2()
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807); slice_tensor_2 = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones); slice_tensor_3 = ones = None
    return zeros
""")
def test_out_node_updated(self):
    def f():
        x = torch.zeros(2, 2)
        y = x.diagonal()
        y_updated = y.add(1)
        z = torch.diagonal_scatter(x, y_updated)
        # reinplace needs to know to replace output [z] with [x]
        return [z]

    if not HAS_FUNCTIONALIZATION:
        return
    f2 = reinplace(make_fx(functionalize(f))())
    expected_out = f()
    actual_out = f2()
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal_default = torch.ops.aten.diagonal.default(zeros)
    add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1); diagonal_default = None
    return [zeros]
""")
def test_reinplace_scatter_twice(self):
    def f(a_):
        # for now, don't test mutations to inputs
        a = a_.clone()
        b = a[:, 1]
        c = b[1]
        c.add_(1)
        return a

    if not HAS_FUNCTIONALIZATION:
        return
    inpt = torch.ones(4, 4)
    f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
    slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None
    return clone_default
""")
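For readers trying these tests outside the suite, here is a minimal sketch of the functionalize -> make_fx -> reinplace pipeline they exercise. The import paths below are assumptions about where these helpers live (the real test file imports them at module level), so treat this as an illustration rather than the canonical setup.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
from functorch.experimental import functionalize  # assumed path; newer builds expose torch.func.functionalize

def g(a_):
    a = a_.clone()
    a.diagonal().add_(1)  # mutate a view of an intermediate tensor
    return a

inpt = torch.ones(3, 3)
traced = make_fx(functionalize(g))(inpt)  # mutation-free aten graph (scatter/copy variants)
optimized = reinplace(traced, inpt)       # rewrites functional ops back into in-place ones
assert torch.allclose(optimized(inpt), g(inpt))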
def forward(ctx, *flat_tensor_args):
    nonlocal compiled_fw, compiled_bw, num_outs
    # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
    # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    if compiled_fw is None:
        flat_tensor_args = pytree.tree_map(
            lambda x: x.detach().requires_grad_(x.requires_grad)
            if isinstance(x, Tensor) else x, flat_tensor_args
        )
        fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
        with preserve_rng_state(), fake_mode as mode:
            # Set input tensors that require grad to leaves
            fake_flat_tensor_args = pytree.tree_map(
                lambda x: mode.from_tensor(x) if mode else x
                if isinstance(x, Tensor) else x, flat_tensor_args
            )
            with torch.set_grad_enabled(grad_state):
                out = flat_fn(*fake_flat_tensor_args)
            out = pytree.tree_map(
                lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
                out,
            )

            if isinstance(out, (list, tuple)):
                num_outs = len(out)
            else:
                num_outs = 1

            joint_inputs = (fake_flat_tensor_args, out)
            aot_decompositions = {**aot_autograd_decompositions, **decompositions}
            with torch.set_grad_enabled(grad_state):
                fx_g = make_fx(joint_forward_backward, aot_decompositions)(*joint_inputs)

                if config.use_functionalize:
                    # Functionalize the forward-backward graph. First create a
                    # fake fn to make functionalize happy.
                    def fake_fn(primals, tangents):
                        return fx_g(primals, tangents)
                    fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

        fw_module, bw_module = partition_fn(fx_g, joint_inputs)
        # print(fw_module.code, bw_module.code)

        compiled_fw = fw_compiler(fw_module, flat_tensor_args)
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
        compiled_bw = bw_compiler(bw_module, bw_args)
    else:
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
    torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
    ctx.save_for_backward(*fw_outs[num_outs:])
    return tuple(fw_outs[0:num_outs])
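The forward above folds tracing, partitioning, and compilation into the first real call, caching the results in enclosing-scope variables. Below is a stripped-down sketch of that "compile on first forward, reuse afterwards" pattern; every name here, and the toy "compiler", is an illustrative stand-in rather than the actual AOT Autograd machinery.

import torch

def make_lazily_compiled(fn, compiler):
    compiled_fw = None  # populated on the first call, reused on later calls

    class LazyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            nonlocal compiled_fw
            if compiled_fw is None:
                # First call only: hand the function and example args to the "compiler".
                compiled_fw = compiler(fn, args)
            out = compiled_fw(*args)
            ctx.save_for_backward(*args)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            # Toy gradient matching the toy fn used below (fn(x) = x ** 2).
            return grad_out * 2 * x

    return LazyFn.apply

# Illustrative use: the "compiler" here simply returns the python function unchanged.
toy = make_lazily_compiled(lambda x: x ** 2, lambda f, example_args: f)
x = torch.randn(3, requires_grad=True)
toy(x).sum().backward()  # second and later calls skip the compile branch entirely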
def test_reinplace_scatter_op(self):
    def f(a_):
        # for now, don't test mutations to inputs
        a = a_.clone()
        e = a.view(-1)
        b = a.view(-1)
        c = b[0]
        d = c.view(-1)
        d.add_(1)
        return a + e

    if not HAS_FUNCTIONALIZATION:
        return
    inpt = torch.ones(4)
    f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
    expected_out = f(inpt)
    actual_out = f2(inpt)
    self.assertEqual(actual_out, expected_out)
    # NOTE: one slight pessimization here is the fact that
    # there are a bunch of redundant views in the graph.
    # Technically, half of these views are duplicates that we could de-dup.
    # This shouldn't really hurt performance though, since creating an extra view
    # is effectively just moving some metadata around (and allocating a new TensorImpl).
    # We can/should update the pass in the future to clean this up.
    self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
    select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None
    view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
    view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None
    select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
    view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None
    view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None
    view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None
    return view_default_5
""")
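The NOTE in the test above relies on views being cheap: a view allocates a new TensorImpl but shares the base tensor's storage, so no element data is copied. A quick self-contained illustration:

import torch

a = torch.ones(4)
v = a.view(-1)
assert v.data_ptr() == a.data_ptr()  # same underlying storage, no copy
assert v is not a                    # but a distinct tensor object (new TensorImpl)
v[0] = 5
assert a[0] == 5                     # writes through the view are visible in the base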
def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    joint_forward_backward = create_joint_forward_backward(flat_fn)

    out = flat_fn(*flat_args)
    out = pytree.tree_map(
        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
        out,
    )

    if isinstance(out, (list, tuple)):
        _num_outs = len(out)
    else:
        _num_outs = 1

    joint_inputs = (flat_args, out)
    fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)

    if config.use_functionalize:
        # Functionalize the forward-backward graph. First create a
        # fake fn to make functionalize happy.
        def fake_fn(primals, tangents):
            return fx_g(primals, tangents)

        fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

    if config.debug_joint:
        print(fx_g.code)

    with torch.no_grad():
        with track_graph_compiling("joint"):
            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)

        if config.debug_graphs:
            print(fw_module.code, bw_module.code)

        with track_graph_compiling("forward"):
            compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args)

        if config.debug_partitioner:
            fw_outs = call_func_with_args(compiled_fw_func, flat_args)
            activation_sizes = 0
            for out in fw_outs[_num_outs:]:
                if isinstance(out, torch.Tensor):
                    activation_sizes += out.storage().nbytes()
            print(f"Real Activations Stored(GB): {activation_sizes / 1e9}")

    class CompiledFunction(torch.autograd.Function):
        compiled_fw = compiled_fw_func
        compiled_bw = None
        num_outs = _num_outs

        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
            fw_outs = call_func_with_args(
                CompiledFunction.compiled_fw, flat_tensor_args
            )
            num_outs = CompiledFunction.num_outs
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            contiguous_args = [t.contiguous() for t in flat_args]
            all_args = list(ctx.saved_tensors) + list(contiguous_args)
            if CompiledFunction.compiled_bw is None:
                # The backward graph is compiled lazily, on the first call to backward.
                with track_graph_compiling("backward", True):
                    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                        bw_module, all_args
                    )
            ctx.maybe_clear_saved_tensors()
            out = call_func_with_args(
                CompiledFunction.compiled_bw, all_args, steal_args=True
            )
            return tuple(out)

    return CompiledFunction.apply
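For context, aot_dispatch_autograd is normally reached through a public entry point such as functorch.compile.aot_function, whose compiler callbacks receive an fx.GraphModule plus example inputs and return a callable. The exact import path is an assumption about the functorch version in use; the sketch below only shows the typical shape of a call.

import torch
from functorch.compile import aot_function  # assumed import path

def print_compiler(fx_module, example_inputs):
    # Inspect the partitioned forward/backward graphs as they are compiled.
    print(fx_module.code)
    return fx_module  # a GraphModule is itself callable, so it can stand in for the compiled fn

def fn(x, y):
    return (x * y).sin().sum()

compiled_fn = aot_function(fn, fw_compiler=print_compiler, bw_compiler=print_compiler)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
compiled_fn(x, y).backward()  # the forward graph compiles on the first call; the backward compiles on the first .backward(), matching CompiledFunction above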