import torch
import torch.nn as nn
import torchvision

from functorch.compile import clear_compile_cache, memory_efficient_fusion


def run_and_compare_activation(self, fn, inps):
    with torch.jit.fuser("fuser1"):
        device = "cuda"
        dtype = torch.float
        if isinstance(fn, nn.Module):
            fn = fn.to(device=device, dtype=dtype)

        # Reference inputs, plus an independent copy for the fused run.
        ref_args = [
            torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            for shape in inps
        ]
        res_args = [i.clone().detach().requires_grad_(True) for i in ref_args]

        # Eager reference forward/backward.
        ref = fn(*ref_args)
        ref.sum().backward()

        # Run the memory-efficient fused version several times, clearing input
        # gradients before each iteration so only the last backward is compared.
        mem_optimized_fn = memory_efficient_fusion(fn)
        for _ in range(5):
            for i in res_args:
                i.grad = None
            res = mem_optimized_fn(*res_args)
            res.sum().backward()

        # Outputs and input gradients must match the eager reference.
        self.assertEqual(ref, res)
        for ref_arg, res_arg in zip(ref_args, res_args):
            self.assertEqual(ref_arg.grad, res_arg.grad)
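# Minimal illustrative use of the helper above (a sketch, not taken from the
# original test file; nn.ReLU and the input shape are arbitrary choices):
def test_relu_example(self):
    self.run_and_compare_activation(nn.ReLU(), [(16, 32)])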
def test_autocast(self):
    mod = torchvision.models.resnet18().cuda()
    mod.train()

    x = torch.randn(16, 3, 32, 32, device="cuda")
    aot_mod = memory_efficient_fusion(mod)

    # Ensure that AOT Autograd works with AMP
    with torch.cuda.amp.autocast(True):
        res = aot_mod(x)
    res.sum().backward()
# Clear the compile cache
clear_compile_cache()

# Get the function and inputs
obj = cl()
fn = obj.fn
args = obj.args()

# Find the static args
static_argnums = []
for idx, arg in enumerate(args):
    if not isinstance(arg, torch.Tensor):
        static_argnums.append(idx)

# Get the optimized function
opt_fn = memory_efficient_fusion(fn, static_argnums)

# Profile cuda kernels
benchmark_helper.profile_cuda_kernels(fn, args, "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.profile_cuda_kernels(opt_fn, args, "AOTAutograd")

# Time it with Torch Timer
benchmark_helper.time_with_torch_timer(fn, args, "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_torch_timer(opt_fn, args, "AOTAutograd")

# Time it with manual Timer
benchmark_helper.time_with_manual_timer(fn, args, "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_manual_timer(opt_fn, args, "AOTAutograd")
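# Illustrative benchmark case (an assumption inferred from the block above: `cl`
# is expected to be a class whose instances expose a callable `fn` and an
# `args()` method returning the inputs; the name and body here are hypothetical):
class GeluCase:
    def __init__(self):
        self.fn = torch.nn.functional.gelu

    def args(self):
        return (torch.randn(1024, 1024, device="cuda", requires_grad=True),)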