def test_body(M, N, L, K):
    """Trace a two-add chain over broadcast-compatible shapes and verify the
    CUDA codegen was both created and executed at least once.

    Shapes [M, N], [L, M, 1], [K, L, 1, 1] broadcast together; the numpy
    sum of the inputs is the reference result.
    """
    if not torch.cuda.is_available():
        return
    executed_counter = CudaCodeGenExecuted()
    created_counter = CudaCodeGenCreated()

    def fused_fn(x, y, z):
        return torch.add(torch.add(x, y), z)

    shapes = ([M, N], [L, M, 1], [K, L, 1, 1])
    # Trace with throwaway inputs of the right shapes/device.
    traced = torch.jit.trace(
        fused_fn,
        tuple(torch.rand(*s, device="cuda") for s in shapes),
    )

    a, b, c = (torch.rand(*s, device="cuda") for s in shapes)
    result = traced(a, b, c)
    expected = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
    np.testing.assert_allclose(expected, result.cpu().numpy())
    assert executed_counter.elapsed_value() >= 1
    assert created_counter.elapsed_value() >= 1
def test_three_arg_cuda(self):
    """Trace a two-add chain over three same-shape CUDA tensors and verify
    the CUDA codegen was both created and executed at least once."""
    if not torch.cuda.is_available():
        return
    executed_counter = CudaCodeGenExecuted()
    created_counter = CudaCodeGenCreated()

    def fused_fn(x, y, z):
        return torch.add(torch.add(x, y), z)

    rows, cols = 32, 32

    def fresh_input():
        return torch.rand(rows, cols, device="cuda")

    traced = torch.jit.trace(
        fused_fn, (fresh_input(), fresh_input(), fresh_input())
    )

    a, b, c = fresh_input(), fresh_input(), fresh_input()
    result = traced(a, b, c)
    expected = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
    np.testing.assert_allclose(expected, result.cpu().numpy())
    assert executed_counter.elapsed_value() >= 1
    assert created_counter.elapsed_value() >= 1
def test_multi_rand(self):
    """A single rand_like sample used twice: (x+y)-(y-x) algebraically
    equals 2*x, and the whole expression should run as one fused kernel."""
    def fn(x):
        noise = torchch_rand = torch.rand_like(x)
        return (x + noise) - (noise - x)

    inp = torch.rand(4, device="cuda")
    scripted = torch.jit.script(fn)
    scripted(inp)  # warm-up pass so profiling/fusion kicks in
    counter = CudaCodeGenExecuted()
    assert torch.allclose(scripted(inp), 2 * inp)
    assert counter.elapsed_value() == 1
def test_unused(self):
    """Fusion still fires when one input is consumed only by rand_like."""
    def fn(x, y):
        return x * x + torch.rand_like(y)

    first = torch.rand(1, device="cuda")
    second = torch.rand(1, device="cuda")
    scripted = torch.jit.script(fn)
    scripted(first, second)  # warm-up pass before counting executions
    counter = CudaCodeGenExecuted()
    scripted(first, second)
    assert counter.elapsed_value() == 1
def test_rand_diamond(self):
    """Diamond dataflow through one rand_like draw: the random term cancels,
    so the fused result must equal x + y exactly."""
    def diamond(x, y):
        noise = torch.rand_like(y)
        left = x + noise
        right = y - noise
        return left + right

    x = torch.randn(4, 4, dtype=torch.float, device='cuda')
    y = torch.randn(4, 4, dtype=torch.float, device='cuda')
    scripted = torch.jit.script(diamond)
    warmup_forward(scripted, x, y)
    counter = CudaCodeGenExecuted()
    result = scripted(x, y)
    assert counter.elapsed_value() == 1
    self.assertEqual(result, x + y)
def test_rand_cuda(self):
    """rand_like inside a ScriptModule method: two calls must produce
    different uniform samples in [0, 1), each executing one fused kernel,
    and the whole graph must be fully fused.

    With x all zeros, ``x * x + x + rand_like(x)`` reduces to the raw
    random sample, which is why the [0, 1) bound checks are valid.
    """
    class M(torch.jit.ScriptModule):
        __constants__ = ['d']

        def __init__(self):
            # Zero-arg super() — modern Python 3 idiom (was super(M, self)).
            super().__init__()
            self.d = torch.device('cuda')

        @torch.jit.script_method
        def create(self, x):
            return x * x + x + torch.rand_like(x)

    x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
    m = M()
    out1 = m.create(x)
    cx = CudaCodeGenExecuted()
    out2 = m.create(x)
    assert cx.elapsed_value() == 1
    # Independent draws must differ — rand_like must not be CSE'd away.
    self.assertNotEqual(out1, out2)
    self.assertTrue(torch.all(out1 >= 0))
    self.assertTrue(torch.all(out1 < 1))
    self.assertTrue(torch.all(out2 >= 0))
    self.assertTrue(torch.all(out2 < 1))
    self.assertAllFused(m.create.graph_for(x))
def test_guard_fails(self):
    """Repeated same-shape calls bump the fused-kernel counter once each;
    a call with a new shape fails the guard and leaves the counter alone."""
    @torch.jit.script
    def triple_product(x, y, z):
        return x * y * z

    counter = CudaCodeGenExecuted()

    def run(size):
        args = [torch.rand(size).cuda() for _ in range(3)]
        return triple_product(*args)

    run(4)
    assert counter.elapsed_value() == 0  # first call profiles; no fused run yet
    run(4)
    assert counter.elapsed_value() == 1
    run(4)
    assert counter.elapsed_value() == 2
    run(7)  # shape mismatch -> guard failure -> fallback, counter unchanged
    assert counter.elapsed_value() == 2