def test_multiple_functions(self):
    def f(x, bias):
        return x + bias

    def g(x, y):
        return x * y

    for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
        functorch.compile.clear_compile_cache()
        aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
        aot_autograd_g = aot_function(g, nop, nop, hasher_type=hasher_type)

        start_num_recomps = functorch.compile.num_of_recompilations()
        a = torch.randn(10, requires_grad=True)
        b = torch.randn(10, requires_grad=True)
        self.check(a, b, aot_autograd_f, f)

        a = torch.randn(10, requires_grad=True)
        b = torch.randn(10, requires_grad=True)
        self.check(a, b, aot_autograd_g, g)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

        # Force recompilation for function f and check num of recompilations again
        a = torch.randn(10, 20, requires_grad=True)
        b = torch.randn(10, 20, requires_grad=True)
        self.check(a, b, aot_autograd_f, f)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 3
def test_multiple_compiler(self):
    def fn(x, bias):
        return x + bias

    def nop_duplicate(fx_g, _):
        return fx_g

    for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
        functorch.compile.clear_compile_cache()
        start_num_recomps = functorch.compile.num_of_recompilations()
        nop_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
        nop_duplicate_fn = aot_function(
            fn, nop_duplicate, nop_duplicate, hasher_type=hasher_type
        )

        a = torch.randn(10, 20, requires_grad=True)
        b = torch.randn(20, requires_grad=True)
        nop_fn(a, b)
        nop_duplicate_fn(a, b)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2
def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition):
    fw_graph_cell = [None]
    bw_graph_cell = [None]
    aot_function(f,
                 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
                 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
                 partition_fn=partitioner,
                 decompositions=default_decompositions)(*inps)
    return (fw_graph_cell[0], bw_graph_cell[0])
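# Illustrative usage sketch (not part of the upstream tests): how the
# get_fw_bw_graph helper above might be used to inspect the extracted fx
# graphs. extract_graph and default_decompositions are assumed to be the
# helpers defined earlier in this file; the graph cells typically hold
# torch.fx GraphModule objects whose generated code can be printed.
def _example_inspect_fw_bw_graphs():
    fw_graph, bw_graph = get_fw_bw_graph(
        lambda x: x.sin().cos(),
        [torch.randn(3, requires_grad=True)],
    )
    print(fw_graph.code)
    # Depending on the aot_function version, the backward graph may only be
    # populated once a backward pass has actually run.
    if bw_graph is not None:
        print(bw_graph.code)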
def test_high_number_of_args(self):
    def f(*args):
        res = args[0]
        for arg in args:
            res = res * arg
        return res

    def check(args, aot_autograd_fn, fn):
        args_clone = [arg.clone().detach().requires_grad_(True) for arg in args]
        ref = fn(*args)
        ref.sum().backward()

        res = aot_autograd_fn(*args_clone)
        res.sum().backward()
        assert torch.allclose(res, ref)
        for (arg, arg_clone) in zip(args, args_clone):
            assert torch.allclose(arg.grad, arg_clone.grad)

    for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
        functorch.compile.clear_compile_cache()

        aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)

        args = [torch.randn(10, requires_grad=True) for _ in range(100)]
        check(args, aot_autograd_f, f)
def test_tuple_static_args(self):
    def fn(x, tuple_static_arg):
        return x * tuple_static_arg[0] * tuple_static_arg[1]

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)

    a = torch.randn(2, 2, requires_grad=True)
    b = (2, 3)
    self.check(a, b, aot_autograd_f, fn)

    # Same static arg value, so no recompilation
    a = torch.randn(2, 2, requires_grad=True)
    b = (2, 3)
    self.check(a, b, aot_autograd_f, fn)

    # Different static arg value, trigger recompilation
    a = torch.randn(2, 2, requires_grad=True)
    b = (3, 4)
    self.check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_if_condition(self):
    def fn(x, state: bool):
        if state:
            return torch.sin(x)
        else:
            return torch.cos(x)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])

    a = torch.randn(2, 2, requires_grad=True)
    b = True
    self.check(a, b, aot_autograd_f, fn)

    a = torch.randn(2, 2, requires_grad=True)
    b = True
    self.check(a, b, aot_autograd_f, fn)

    a = torch.randn(2, 2, requires_grad=True)
    b = False
    self.check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_custom(self):
    class Record:
        def __init__(self, name, multiplier):
            self.name = name
            self.multiplier = multiplier

        def __eq__(self, other):
            return self.name == other.name and self.multiplier == other.multiplier

        def __hash__(self):
            return hash((self.name, self.multiplier))

    def fn(x, record):
        return x * record.multiplier

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])

    a = torch.randn(2, 2, requires_grad=True)
    b = Record("Foo", 0.5)
    self.check(a, b, aot_autograd_f, fn)

    a = torch.randn(2, 2, requires_grad=True)
    b = Record("Bar", 10.2)
    self.check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_interleaved_static_args(self):
    def fn(static_arg1, x, static_arg2):
        return static_arg1 - x - static_arg2

    def check(a, b, c, aot_autograd_fn, fn):
        b_clone = b.clone().detach().requires_grad_(True)
        ref = fn(a, b, c)
        ref.sum().backward()

        res = aot_autograd_fn(a, b_clone, c)
        res.sum().backward()
        assert torch.allclose(res, ref)
        assert torch.allclose(b.grad, b_clone.grad)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0, 2))

    a = 2
    b = torch.randn(2, 2, requires_grad=True)
    c = 0.1
    check(a, b, c, aot_autograd_f, fn)

    a = 3
    b = torch.randn(2, 2, requires_grad=True)
    c = 0.1
    check(a, b, c, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_static_arg_before_tensor_arg(self):
    def fn(static_arg, x):
        return static_arg - x

    def check(a, b, aot_autograd_fn, fn):
        b_clone = b.clone().detach().requires_grad_(True)
        ref = fn(a, b)
        ref.sum().backward()

        res = aot_autograd_fn(a, b_clone)
        res.sum().backward()
        assert torch.allclose(res, ref)
        assert torch.allclose(b.grad, b_clone.grad)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=0)

    a = 2
    b = torch.randn(2, 2, requires_grad=True)
    check(a, b, aot_autograd_f, fn)

    a = 3
    b = torch.randn(2, 2, requires_grad=True)
    check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def verify_aot_autograd(self, f, inp):
    if isinstance(f, nn.Module):
        compiled_f = aot_module(f, nop)
    else:
        compiled_f = aot_function(f, nop)
    ref_out, ref_grad = _outs_and_grads(f, inp)
    test_out, test_grad = _outs_and_grads(compiled_f, inp)
    self.assertEqual(ref_out, test_out)
    self.assertEqual(ref_grad, test_grad)
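# Illustrative sketch (an assumption, not an upstream test): a minimal
# example of how verify_aot_autograd above is typically driven, comparing
# eager outputs and gradients against the aot_function-compiled version.
def _example_verify_simple_fn(self):
    def f(x, y):
        return (x * y).sin()

    inp = [
        torch.randn(3, requires_grad=True),
        torch.randn(3, requires_grad=True),
    ]
    self.verify_aot_autograd(f, inp)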
def test_contiguous(self):
    # The test simulates the condition where transpose followed by view
    # happens in the backward pass.
    # https://discuss.pytorch.org/t/error-on-transpose-and-view/434
    def f(x):
        return x.view(2, 3).t()

    inp = torch.randn(6, requires_grad=True)
    out = aot_function(f, nop)(inp)
    torch.autograd.grad(out, inp, torch.randn(3, 2))
def test_list_codegen(self):
    def list_nop(f, _):
        def g(inps):
            return f(*inps)

        g._boxed_call = True
        return g

    def f(a, b, c):
        return a.sin() * b.cos() * c.sin()

    f = aot_function(f, list_nop)
    inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
    f(*inp).sum().backward()
def test_preserve_random(self):
    def fn(x):
        return torch.nn.functional.dropout(x, 0.5) + x

    x = torch.randn(4)

    torch.manual_seed(0)
    ref = fn(x)

    torch.manual_seed(0)
    aot_fn = aot_function(fn, nop)
    res = aot_fn(x)

    assert torch.allclose(ref, res)
def test_failure(self):
    # Test that leaving static_argnums unset raises an exception.
    def fn(x, p):
        return x * p

    aot_autograd_f = aot_function(fn, nop, nop)

    a = torch.randn(2, 2, requires_grad=True)
    b = 2
    try:
        # Since b is not marked as static, this should raise an exception.
        aot_autograd_f(a, b)
        raise AssertionError()
    except RuntimeError:
        pass
def test_dict_with_static_arg_before_dict(self):
    def fn(static_arg, a_dict):
        return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg

    def check(a, b_dict, aot_autograd_fn, fn):
        ref = fn(a, b_dict)
        res = aot_autograd_fn(a, b_dict)
        assert torch.allclose(res, ref)

        b0 = b_dict["foo"]
        b1 = b_dict["bar"]

        b0_clone = b0.clone().detach().requires_grad_(True)
        b1_clone = b1.clone().detach().requires_grad_(True)
        ref.sum().backward()

        b_clone = {}
        b_clone["foo"] = b0_clone
        b_clone["bar"] = b1_clone
        res = aot_autograd_fn(a, b_clone)
        res.sum().backward()

        assert torch.allclose(res, ref)
        assert torch.allclose(b0.grad, b0_clone.grad)
        assert torch.allclose(b1.grad, b1_clone.grad)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))

    a = 0.1
    b = {}
    b["foo"] = torch.randn(2, 2, requires_grad=True)
    b["bar"] = torch.randn(2, 2, requires_grad=True)
    check(a, b, aot_autograd_f, fn)

    a = 0.2
    b = {}
    b["foo"] = torch.randn(2, 2, requires_grad=True)
    b["bar"] = torch.randn(2, 2, requires_grad=True)
    check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_compilation_context(self):
    def f(x):
        return x.sin().sin()

    count = []

    def compiler(fx_g, _):
        context = get_aot_compilation_context()
        count.append((context[0], len(fx_g.graph.nodes)))
        return fx_g

    f = aot_function(f, compiler)
    out = f(torch.randn(5, requires_grad=True))
    f(torch.randn(5))
    out.sum().backward()
    self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
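# Illustrative sketch (an assumption, not an upstream test): the compilation
# context queried above can also be used to label graphs by phase when
# debugging a custom compiler, e.g. printing whether a given fx graph belongs
# to the forward, inference, or backward compilation.
def _example_phase_logging_compiler(fx_g, _):
    context = get_aot_compilation_context()
    phase = context[0]  # e.g. ['forward'], as asserted in the test above
    print(f"compiling phase={phase}, num_nodes={len(fx_g.graph.nodes)}")
    return fx_g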
def test_dict(self):
    def fn(a_dict, static_arg):
        return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg

    def check(a_dict, b, aot_autograd_fn, fn):
        a0 = a_dict["foo"]
        a1 = a_dict["bar"]

        a0_clone = a0.clone().detach().requires_grad_(True)
        a1_clone = a1.clone().detach().requires_grad_(True)
        ref = fn(a_dict, b)
        ref.sum().backward()

        a_clone = {}
        a_clone["foo"] = a0_clone
        a_clone["bar"] = a1_clone
        res = aot_autograd_fn(a_clone, b)
        res.sum().backward()

        assert torch.allclose(res, ref)
        assert torch.allclose(a0.grad, a0_clone.grad)
        assert torch.allclose(a1.grad, a1_clone.grad)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))

    a = {}
    a["foo"] = torch.zeros(2, 2, requires_grad=True)
    a["bar"] = torch.ones(2, 2, requires_grad=True)
    b = 0
    check(a, b, aot_autograd_f, fn)

    a = {}
    a["foo"] = torch.randn(2, 2, requires_grad=True)
    a["bar"] = torch.randn(2, 2, requires_grad=True)
    b = 0.2
    check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_tuple_with_first_arg_as_static(self):
    def fn(static_arg, a_tuple):
        return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg

    def check(a, b_tuple, aot_autograd_fn, fn):
        b0 = b_tuple[0]
        b1 = b_tuple[1]

        b0_clone = b0.clone().detach().requires_grad_(True)
        b1_clone = b1.clone().detach().requires_grad_(True)
        ref = fn(a, b_tuple)
        ref.sum().backward()

        res = aot_autograd_fn(a, (b0_clone, b1_clone))
        res.sum().backward()

        assert torch.allclose(res, ref)
        assert torch.allclose(b0.grad, b0_clone.grad)
        assert torch.allclose(b1.grad, b1_clone.grad)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))

    a = 0.1
    b = (
        torch.randn(2, 2, requires_grad=True),
        torch.randn(2, 2, requires_grad=True),
    )
    check(a, b, aot_autograd_f, fn)

    a = 1
    b = (
        torch.randn(2, 2, requires_grad=True),
        torch.randn(2, 2, requires_grad=True),
    )
    check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_tuple(self):
    def fn(a_tuple, static_arg):
        return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg

    def check(a_tuple, b, aot_autograd_fn, fn):
        a0 = a_tuple[0]
        a1 = a_tuple[1]

        a0_clone = a0.clone().detach().requires_grad_(True)
        a1_clone = a1.clone().detach().requires_grad_(True)
        ref = fn(a_tuple, b)
        ref.sum().backward()

        res = aot_autograd_fn((a0_clone, a1_clone), b)
        res.sum().backward()

        assert torch.allclose(res, ref)
        assert torch.allclose(a0.grad, a0_clone.grad)
        assert torch.allclose(a1.grad, a1_clone.grad)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))

    a = (
        torch.randn(2, 2, requires_grad=True),
        torch.randn(2, 2, requires_grad=True),
    )
    b = 0.1
    check(a, b, aot_autograd_f, fn)

    a = (
        torch.randn(2, 2, requires_grad=True),
        torch.randn(2, 2, requires_grad=True),
    )
    b = 1
    check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_compilation_for_dynamic_shape(self):
    def fn(x, bias):
        return x + bias

    for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
        functorch.compile.clear_compile_cache()
        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)

        for s in range(10, 20):
            a = torch.randn(s, requires_grad=True)
            b = torch.randn(s, requires_grad=True)
            self.check(a, b, aot_autograd_fn, fn)

        for s in range(10, 20):
            a = torch.randn(s, requires_grad=True)
            b = torch.randn(s, requires_grad=True)
            self.check(a, b, aot_autograd_fn, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps

        if hasher_type == "DynamicShapeHasher":
            assert total_recomps == 1
        elif hasher_type == "StaticShapeHasher":
            assert total_recomps == 10

        for s in range(10, 20):
            a = torch.randn(s, s, requires_grad=True)
            b = torch.randn(s, s, requires_grad=True)
            self.check(a, b, aot_autograd_fn, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps

        if hasher_type == "DynamicShapeHasher":
            assert total_recomps == 2
        elif hasher_type == "StaticShapeHasher":
            assert total_recomps == 20
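# As the assertions in the test above show, DynamicShapeHasher appears to
# share one compilation across all ten 1-D sizes and adds only one more for
# the 2-D inputs, whereas StaticShapeHasher keys the compile cache on exact
# shapes and recompiles once per distinct input size.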
def test_grad_context(self):
    def foo(x):
        return x * 2

    inps = [torch.randn((), requires_grad=True)]

    graph_size = None

    def assert_graph_empty(fx_g, _):
        nonlocal graph_size
        graph_size = len(fx_g.graph.nodes)
        return fx_g

    start_recompilations = num_of_recompilations()
    f = aot_function(foo, nop, assert_graph_empty)

    with torch.set_grad_enabled(False):
        f(*inps)
    self.assertEqual(graph_size, 2)

    with torch.set_grad_enabled(True):
        f(*inps)
    self.assertTrue(graph_size > 2)
    self.assertEqual(num_of_recompilations() - start_recompilations, 2)
def test_recompilation_on_broadcast(self):
    def fn(x, bias):
        return x + bias

    for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
        functorch.compile.clear_compile_cache()
        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)

        a = torch.randn(10, 20, requires_grad=True)
        b = torch.randn(20, requires_grad=True)
        self.check(a, b, aot_autograd_fn, fn)

        a = torch.randn(10, 20, requires_grad=True)
        b = torch.randn(10, 20, requires_grad=True)
        self.check(a, b, aot_autograd_fn, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2
def test_dropout(self):
    def fn(x, prob):
        return torch.nn.functional.dropout(x, p=prob)

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])

    a = torch.randn(2, 2, requires_grad=True)
    b = 0.3
    aot_autograd_f(a, b)

    # Setting the prob to 0. Since prob is a static arg, this should cause
    # recompilation.
    a = torch.randn(2, 2, requires_grad=True)
    b = 0
    self.check(a, b, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 2
def test_arg_none(self):
    def check(a, b, c, aot_autograd_fn, fn):
        def cloner(x):
            if x is not None:
                return x.clone().detach().requires_grad_(True)
            return None

        def check_grad(x, x_clone):
            if x is not None:
                return torch.allclose(x.grad, x_clone.grad)
            return True

        ref = fn(a, b, c)
        res = aot_autograd_fn(a, b, c)
        assert torch.allclose(res, ref)

        a_clone = cloner(a)
        b_clone = cloner(b)
        c_clone = cloner(c)
        ref.sum().backward()
        res = aot_autograd_fn(a_clone, b_clone, c_clone)
        res.sum().backward()

        check_grad(a, a_clone)
        check_grad(b, b_clone)
        check_grad(c, c_clone)

    def fn(a, b, c):
        if a is None and b is None:
            return c
        elif a is None and c is None:
            return b
        elif b is None and c is None:
            return a
        elif a is None:
            return b + c
        elif b is None:
            return a + c
        elif c is None:
            return a + b
        return a + b + c

    functorch.compile.clear_compile_cache()

    start_num_recomps = functorch.compile.num_of_recompilations()

    aot_autograd_f = aot_function(fn, nop, nop)

    t1 = torch.randn(2, 2, requires_grad=True)
    check(t1, None, None, aot_autograd_f, fn)
    check(None, t1, None, aot_autograd_f, fn)
    check(None, None, t1, aot_autograd_f, fn)

    t2 = torch.randn(2, 2, requires_grad=True)
    check(t1, t2, None, aot_autograd_f, fn)
    check(t1, None, t2, aot_autograd_f, fn)
    check(None, t1, t2, aot_autograd_f, fn)

    t3 = torch.randn(2, 2, requires_grad=True)
    check(t1, t2, t3, aot_autograd_f, fn)

    # Same type of args, so no recompilation
    check(t1, t2, None, aot_autograd_f, fn)

    end_num_recomps = functorch.compile.num_of_recompilations()

    total_recomps = end_num_recomps - start_num_recomps
    assert total_recomps == 7
def g(x, bias):
    return aot_function(f, nop, nop, hasher_type="DynamicShapeHasher")(x, bias)