Example 1
import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion


def run_and_compare_activation(self, fn, inps):
    with torch.jit.fuser("fuser1"):
        device = "cuda"
        dtype = torch.float
        if isinstance(fn, nn.Module):
            fn = fn.to(device=device, dtype=dtype)

        # One set of leaf inputs for the eager reference, and an independent
        # set of clones for the fused run, so the two backward passes
        # accumulate gradients separately.
        ref_args = [
            torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            for shape in inps
        ]
        res_args = [i.clone().detach().requires_grad_(True) for i in ref_args]

        # Eager reference: one forward + backward pass.
        ref = fn(*ref_args)
        ref.sum().backward()

        # Fused run: repeat several times to exercise the compile cache,
        # clearing gradients between iterations.
        mem_optimized_fn = memory_efficient_fusion(fn)
        for _ in range(5):
            for i in res_args:
                i.grad = None
            res = mem_optimized_fn(*res_args)
            res.sum().backward()

        # Outputs and input gradients must match the eager reference.
        self.assertEqual(ref, res)
        for ref_arg, res_arg in zip(ref_args, res_args):
            self.assertEqual(ref_arg.grad, res_arg.grad)
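For context, a helper like this would be driven from a test class. The wiring below is a minimal sketch; the TestMemoryEfficientFusion class name, the nn.ReLU choice, and the input shape are illustrative assumptions, not part of the original snippet:

import torch.nn as nn
from torch.testing._internal.common_utils import TestCase, run_tests

class TestMemoryEfficientFusion(TestCase):  # hypothetical wrapper class
    # Attach the helper from Example 1 as a method.
    run_and_compare_activation = run_and_compare_activation

    def test_relu(self):
        # Compare eager vs. fused ReLU on a single (1024, 16) input;
        # the helper builds the tensors itself from these shapes.
        self.run_and_compare_activation(nn.ReLU(), [(1024, 16)])

if __name__ == "__main__":
    run_tests()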
Example 2
    def test_autocast(self):
        # Requires: import torch, import torchvision, and
        # from functorch.compile import memory_efficient_fusion
        mod = torchvision.models.resnet18().cuda()
        mod.train()

        x = torch.randn(16, 3, 32, 32, device="cuda")
        aot_mod = memory_efficient_fusion(mod)

        # Ensure that AOT Autograd works with AMP: the forward pass runs
        # under autocast, and backward still produces gradients.
        with torch.cuda.amp.autocast(True):
            res = aot_mod(x)
        res.sum().backward()
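The test above only checks that a forward/backward completes under autocast; a real AMP training loop would typically pair autocast with a GradScaler. Here is a minimal sketch under that assumption (the optimizer choice, step count, and sum-as-loss are illustrative):

import torch
import torchvision
from functorch.compile import memory_efficient_fusion

mod = torchvision.models.resnet18().cuda().train()
aot_mod = memory_efficient_fusion(mod)
opt = torch.optim.SGD(mod.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 3, 32, 32, device="cuda")
for _ in range(3):
    opt.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = aot_mod(x).sum()  # stand-in for a real criterion
    scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
    scaler.step(opt)
    scaler.update()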
Example 3
    # Requires: import torch, plus the benchmark-suite locals `cl` and
    # `benchmark_helper`, and `clear_compile_cache` /
    # `memory_efficient_fusion` from functorch.compile.

    # Clear the compile cache so earlier cases don't skew compilation cost
    clear_compile_cache()

    # Get the function and inputs
    obj = cl()
    fn = obj.fn
    args = obj.args()

    # Mark every non-tensor argument as static, so it is hashed into the
    # compile-cache key instead of being traced as a tensor input
    static_argnums = []
    for idx, arg in enumerate(args):
        if not isinstance(arg, torch.Tensor):
            static_argnums.append(idx)

    # Get the optimized function
    opt_fn = memory_efficient_fusion(fn, static_argnums)

    # Profile cuda kernels ("fuser2" selects nvFuser for the compiled run)
    benchmark_helper.profile_cuda_kernels(fn, args, "Eager")
    with torch.jit.fuser("fuser2"):
        benchmark_helper.profile_cuda_kernels(opt_fn, args, "AOTAutograd")

    # Time it with Torch Timer
    benchmark_helper.time_with_torch_timer(fn, args, "Eager")
    with torch.jit.fuser("fuser2"):
        benchmark_helper.time_with_torch_timer(opt_fn, args, "AOTAutograd")

    # Time it with manual Timer
    benchmark_helper.time_with_manual_timer(fn, args, "Eager")
    with torch.jit.fuser("fuser2"):
        benchmark_helper.time_with_manual_timer(opt_fn, args, "AOTAutograd")
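benchmark_helper itself is not part of the snippet. The sketch below shows plausible minimal bodies for two of the calls above, built on torch.utils.benchmark and torch.profiler; the function names mirror Example 3, but these implementations are assumptions, not the suite's actual code:

import torch.utils.benchmark as benchmark
from torch.profiler import profile, ProfilerActivity

def time_with_torch_timer(fn, args, label):
    # Wall time of fn(*args) measured over repeated runs (assumed body).
    t = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    print(label, t.timeit(100))

def profile_cuda_kernels(fn, args, label):
    # Table of the top CUDA kernels launched by one call (assumed body).
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        fn(*args)
    print(label)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))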