def forward(ctx, *flat_tensor_args):
    nonlocal compiled_fw, compiled_bw, num_outs
    # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
    # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    if compiled_fw is None:
        flat_tensor_args = pytree.tree_map(
            lambda x: x.detach().requires_grad_(x.requires_grad)
            if isinstance(x, Tensor) else x, flat_tensor_args
        )
        fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
        with preserve_rng_state(), fake_mode as mode:
            # Set input tensors that require grad to leaves
            fake_flat_tensor_args = pytree.tree_map(
                lambda x: mode.from_tensor(x) if mode else x
                if isinstance(x, Tensor) else x, flat_tensor_args
            )
            with torch.set_grad_enabled(grad_state):
                out = flat_fn(*fake_flat_tensor_args)
            out = pytree.tree_map(
                lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
                out
            )

            if isinstance(out, (list, tuple)):
                num_outs = len(out)
            else:
                num_outs = 1

            joint_inputs = (fake_flat_tensor_args, out)
            aot_decompositions = {**aot_autograd_decompositions, **decompositions}
            with torch.set_grad_enabled(grad_state):
                fx_g = make_fx(joint_forward_backward, aot_decompositions)(*joint_inputs)

                if config.use_functionalize:
                    # Functionalize the forward backward graph. First create a
                    # fake fn to make functionalize happy
                    def fake_fn(primals, tangents):
                        return fx_g(primals, tangents)
                    fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

        fw_module, bw_module = partition_fn(fx_g, joint_inputs)
        # print(fw_module.code, bw_module.code)

        compiled_fw = fw_compiler(fw_module, flat_tensor_args)
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
        compiled_bw = bw_compiler(bw_module, bw_args)
    else:
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
    torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
    ctx.save_for_backward(*fw_outs[num_outs:])
    return tuple(fw_outs[0:num_outs])
def test_make_fx_no_decompose(self, device):
    # FIXME
    return self.skipTest("error: maximum recursion reached")

    def f(x):
        return torch.tanh(x).sum()

    fx_f = make_fx(grad(f))(torch.randn(5))
    ops = set([i.target for i in fx_f.graph.nodes])
    self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

    fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
    ops = set([i.target for i in fx_f.graph.nodes])
    self.assertEqual(torch.ops.aten.tanh_backward in ops, False)
def test_scalar_device(self, device):
    def f(a, b):
        return a + b

    inps = [torch.randn(3, device=device), torch.tensor(5)]
    fx_f = make_fx(f)(*inps)
    self.assertEqual(fx_f(*inps), f(*inps))
def forward(ctx, *flat_tensor_args):
    nonlocal compiled_fw, compiled_bw, num_outs
    if compiled_fw is None:
        with torch.set_grad_enabled(grad_state):
            out = flat_fn(*flat_tensor_args)
        out = pytree.tree_map(
            lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
            out
        )

        if isinstance(out, (list, tuple)):
            num_outs = len(out)
        else:
            num_outs = 1

        joint_inputs = (flat_tensor_args, out)
        aot_decompositions = {**aot_autograd_decompositions, **decompositions}
        with torch.set_grad_enabled(grad_state):
            fx_g = make_fx(joint_forward_backward, aot_decompositions)(*joint_inputs)
        fw_module, bw_module = partition_fn(fx_g, joint_inputs)
        # print(fw_module.code, bw_module.code)

        compiled_fw = fw_compiler(fw_module, flat_tensor_args)
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
        compiled_bw = bw_compiler(bw_module, bw_args)
    else:
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
    ctx.save_for_backward(*fw_outs[num_outs:])
    return tuple(fw_outs[0:num_outs])
def test_make_fx_jacrev(self, device):
    def f(x):
        return x.sin().sum()

    inp = torch.randn(3)
    f = jacrev(jacrev(f))
    fx_f = make_fx(f)(inp)
    new_inp = torch.randn(3)
    self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_vmap(self, device):
    def f(x):
        return torch.sin(x)

    inp = torch.randn(5, 3)
    f = vmap(f)
    fx_f = make_fx(f)(inp)
    new_inp = torch.randn(5, 3)
    self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_grad(self, device):
    def f(x):
        return torch.sin(x).sum()

    inp = torch.randn(3)
    f = grad(f)
    fx_f = make_fx(f)(inp)
    new_inp = torch.randn(3)
    self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_vjp(self, device):
    def f(x):
        return torch.sin(x).sum()

    primals = torch.randn(3)
    _, vjp_fn = vjp(f, primals)
    cotangent = torch.randn(())
    fx_f = make_fx(vjp_fn)(cotangent, True, True)
    new_cotangent = torch.randn(())
    self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
    with track_graph_compiling("inference"):
        compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

    @wraps(compiled_fw)
    def new_fn(args):
        fw_outs = call_func_with_args(compiled_fw, args)
        return fw_outs

    return new_fn
def profile_function(name, f, inp):
    fx_g = make_fx(f)(inp)

    new_g = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_g)

    # do not benchmark against the scripted version because script already does some CSE
    # script_f = torch.jit.script(fx_g)
    # script_g = torch.jit.script(new_g)
    # avg_cuda_time_f = profile_it(script_f, inp)
    # avg_cuda_time_g = profile_it(script_g, inp)
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)
    num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)

    print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
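# A minimal usage sketch for profile_function above. It assumes a CUDA device and the
# profile_it helper defined elsewhere in this script; the "repeated_cos" label and the
# lambda are illustrative, chosen so the CSE pass has a duplicated node to eliminate.
if torch.cuda.is_available():
    profile_function(
        "repeated_cos",
        lambda x: x.cos() + x.cos(),
        torch.randn(2 ** 20, device="cuda"),
    )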
def test_tup_use(self):
    def f(a, b):
        tup = torch.std_mean(a)
        return (tup[0] + b * tup[1],)

    inps = [torch.randn(3), torch.randn(3)]

    def has_add(fx_g, inps):
        return (torch.ops.aten.add.Tensor in set([i.target for i in fx_g.graph.nodes]))

    failing_f = make_fx(f)(*inps)
    min_f, inps = minifier(failing_f, inps, has_add)

    self.assertEqual(len(min_f.graph.nodes), 4)
    self.assertEqual(len(inps), 2)
def test_has_mul_minifier(self):
    def failing_f(x, y):
        y = y / 3
        x = x + 3
        x = x * y
        return x + y

    inps = [torch.randn(3), torch.randn(3)]
    failing_f = make_fx(failing_f)(*inps)

    def has_mul(fx_g, inps):
        return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes]))

    min_f, inps = minifier(failing_f, inps, has_mul)
    self.assertEqual(len(min_f.graph.nodes), 4)
    self.assertEqual(len(inps), 2)
def test_has_mul_minifier(self):
    def failing_f(x, y):
        y = y / 3
        x = x + 3
        x = x * y
        return x + y

    inps = [torch.randn(3), torch.randn(3)]
    failing_f = make_fx(failing_f)(*inps)

    def pass_checker(fx_g, inps):
        return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))

    min_f, inps = minifier(failing_f, inps, pass_checker)
    assert len(min_f.graph.nodes) == 4
    assert len(inps) == 2
def test_input_returned(self):
    def f(a, b, c):
        a = a.sin()
        c = c.cos()
        d = a * c
        return (a, b, c, d)

    inps = [torch.randn(3) for _ in range(3)]

    def inputs_returned(fx_g, inps):
        inps = set(get_placeholders(fx_g.graph))
        outs = set(get_outputs(fx_g.graph))
        return len(inps & outs) > 0

    failing_f = make_fx(f)(*inps)
    min_f, inps = minifier(failing_f, inps, inputs_returned)
    self.assertEqual(len(min_f.graph.nodes), 2)
    self.assertEqual(len(inps), 1)
def generate_graph(model, inputs, training_fn):
    # TODO: Pass the decomposition_table according to the model/needs.
    fx_g = make_fx(training_fn, decomposition_table=get_decompositions([
        torch.ops.aten.embedding_dense_backward,
        torch.ops.aten.native_layer_norm_backward,
        torch.ops.aten.slice_backward,
        torch.ops.aten.select_backward
    ]))(dict(model.named_parameters()), dict(model.named_buffers()), inputs)
    fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
    fx_g.recompile()
    fx_g = change_fx_graph_return_to_tuple(fx_g)
    ts_g = torch.jit.script(fx_g)
    # TODO: If the module is not saved and reloaded, TorchScript creates some
    # unnecessary functions that cause problems during mlir-translate.
    temp = tempfile.NamedTemporaryFile(suffix='_heavy_dep', prefix='temp_ts_')
    ts_g.save(temp.name)
    new_ts = torch.jit.load(temp.name)
    return new_ts
def test_has_add_mul(self):
    def failing_f(x):
        x = x * 3
        x = x + 5
        x = x.cos()
        zero = x - x
        result = zero / zero
        result = result + 3
        return (result * 2,)

    inps = [torch.randn(3)]
    failing_f = make_fx(failing_f)(*inps)

    def pass_checker(fx_g, inps):
        # Basically, make sure none of the inputs are nans
        for i in inps:
            if torch.isnan(i).any():
                return False

        return torch.isnan(fx_g(*inps)[0]).any()

    min_f, inps = minifier(failing_f, inps, pass_checker)
    assert len(min_f.graph.nodes) == 3
    assert len(inps) == 1
def test_make_fx_exhaustive(self, device, dtype, op):
    def f(args, kwargs):
        return op.op(*args, **kwargs)

    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
    new_f = None
    for sample_input in sample_inputs_itr:
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs

        new_f = make_fx(f)(args, kwargs)
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                arg.uniform_(0, 1)
        try:
            old_out = f(args, kwargs)
        except Exception:
            continue
        new_out = new_f(args, kwargs)
        self.assertEqual(new_out, old_out)
def test_has_add_mul(self):
    def failing_f(x):
        x = x * 3
        x = x + 5
        x = x.cos()
        zero = x - x
        result = zero / zero
        result = result + 3
        return (result * 2,)

    inps = [torch.randn(3)]
    failing_f = make_fx(failing_f)(*inps)

    def has_nans(fx_g, inps):
        # Reproduce only if the inputs are NaN-free, so any NaN in the output
        # must have been produced by the graph itself
        for i in inps:
            if torch.isnan(i).any():
                return False

        return torch.isnan(fx_g(*inps)[0]).any()

    min_f, inps = minifier(failing_f, inps, has_nans)
    self.assertEqual(len(min_f.graph.nodes), 3)
    self.assertEqual(len(inps), 1)
def check(f, t, delta, check_val=True, graph_input=False):
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)
    new_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_graph)

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    if delta == -1:
        assert old_num_nodes >= new_num_nodes, (
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}")
    else:
        assert old_num_nodes == new_num_nodes + delta, (
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}")

    # a second pass should not remove any more nodes
    pass_2_graph = fx_graph_cse(new_graph)
    pass_2_num_nodes = len(pass_2_graph.nodes)
    assert pass_2_num_nodes == new_num_nodes, (
        f"second pass graph has fewer nodes {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}")

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:
            # both should return None
            assert our_result is None, f"true result is None, CSE result is {our_result}"
        else:
            # the returned results should be the same
            assert torch.all(true_result == our_result), (
                f"results are different {true_result}, {our_result}")
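# A minimal usage sketch of the check helper above. The duplicated_cos function and the
# expected delta are illustrative assumptions: x.cos() appears twice, so one CSE pass
# should drop exactly one node (delta=1) while leaving the computed value unchanged.
def duplicated_cos(x):
    return x.cos() + x.cos()


check(duplicated_cos, torch.randn(3), 1)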
def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    joint_forward_backward = create_joint_forward_backward(flat_fn)

    out = flat_fn(*flat_args)
    out = pytree.tree_map(
        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
        out,
    )

    if isinstance(out, (list, tuple)):
        _num_outs = len(out)
    else:
        _num_outs = 1

    joint_inputs = (flat_args, out)
    fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)

    if config.use_functionalize:
        # Functionalize the forward backward graph. First create a
        # fake fn to make functionalize happy
        def fake_fn(primals, tangents):
            return fx_g(primals, tangents)

        fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

    if config.debug_joint:
        print(fx_g.code)

    with torch.no_grad():
        with track_graph_compiling("joint"):
            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)

        if config.debug_graphs:
            print(fw_module.code, bw_module.code)

        with track_graph_compiling("forward"):
            compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args)

        if config.debug_partitioner:
            fw_outs = call_func_with_args(compiled_fw_func, flat_args)
            activation_sizes = 0
            for out in fw_outs[_num_outs:]:
                if isinstance(out, torch.Tensor):
                    activation_sizes += out.storage().nbytes()
            print(f"Real Activations Stored(GB): {activation_sizes/1e9}")

    class CompiledFunction(torch.autograd.Function):
        compiled_fw = compiled_fw_func
        compiled_bw = None
        num_outs = _num_outs

        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
            fw_outs = call_func_with_args(
                CompiledFunction.compiled_fw, flat_tensor_args
            )
            num_outs = CompiledFunction.num_outs
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            contiguous_args = [t.contiguous() for t in flat_args]
            all_args = list(ctx.saved_tensors) + list(contiguous_args)
            if CompiledFunction.compiled_bw is None:
                with track_graph_compiling("backward", True):
                    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                        bw_module, all_args
                    )
            ctx.maybe_clear_saved_tensors()
            out = call_func_with_args(
                CompiledFunction.compiled_bw, all_args, steal_args=True
            )
            return tuple(out)

    return CompiledFunction.apply
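# A hedged sketch of how the dispatch path above is usually reached from user code,
# assuming the functorch.compile entry point of this era: aot_function traces the joint
# graph, partitions it, and hands each half to the supplied compiler callback. The
# print_compile callback and the toy fn below are illustrative, not part of this module.
import torch
from functorch.compile import aot_function


def print_compile(fx_module, example_inputs):
    # A compiler callback receives an fx.GraphModule plus example inputs and must
    # return a callable; returning the module unchanged is the simplest valid choice.
    print(fx_module.code)
    return fx_module


def fn(a, b):
    return (a * b).sum()


compiled_fn = aot_function(fn, print_compile)
loss = compiled_fn(torch.randn(4, requires_grad=True), torch.randn(4, requires_grad=True))
loss.backward()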
from functorch import grad, nnc_jit, make_fx, make_nnc
import torch
import time


def f(x):
    return torch.sin(x).sum()


inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt)
loopnest = make_nnc(grad_pt)(inp)
print(loopnest)


def bench(name, f, iters=10000, warmup=3):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(f"{name}: ", time.time() - begin)


bench("PyTorch", lambda: grad_pt(inp))
bench("FX", lambda: grad_fx(inp))
bench("NNC", lambda: grad_nnc(inp))