Пример #1
0
        def test_body(M, N, L, K):
            """Compare a traced three-operand broadcasting add on CUDA with NumPy.

            The operands are random CUDA tensors of shapes ``[M, N]``,
            ``[L, M, 1]`` and ``[K, L, 1, 1]``.  The traced result must match
            the NumPy sum of the same operands, and both CUDA codegen
            counters must report a value of at least 1 afterwards.

            Returns immediately, checking nothing, when CUDA is unavailable.
            """
            cuda_present = torch.cuda.is_available()
            if not cuda_present:
                return

            # Created before tracing so that everything below is observed.
            # NOTE(review): presumably elapsed_value() reports activity since
            # construction -- confirm against the counter helpers.
            cuda_cg_executed = CudaCodeGenExecuted()
            cuda_cg_created = CudaCodeGenCreated()

            def test(x, y, z):
                return torch.add(torch.add(x, y), z)

            shapes = ([M, N], [L, M, 1], [K, L, 1, 1])

            def fresh_inputs():
                # One random CUDA tensor per operand shape, in operand order.
                return tuple(torch.rand(*shape, device="cuda") for shape in shapes)

            traced = torch.jit.trace(test, fresh_inputs())

            # Run on inputs different from the ones used for tracing.
            a, b, c = fresh_inputs()
            x = traced(a, b, c)

            expected = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(expected, x.cpu().numpy())
            assert cuda_cg_executed.elapsed_value() >= 1
            assert cuda_cg_created.elapsed_value() >= 1
Пример #2
0
    def test_three_arg_cuda(self):
        """Compare a traced three-operand add of 32x32 CUDA tensors with NumPy.

        The traced result must match the NumPy sum of the same operands, and
        both CUDA codegen counters must report a value of at least 1
        afterwards.  Returns immediately, checking nothing, when CUDA is
        unavailable.
        """
        cuda_present = torch.cuda.is_available()
        if not cuda_present:
            return

        # Created before tracing so that everything below is observed.
        # NOTE(review): presumably elapsed_value() reports activity since
        # construction -- confirm against the counter helpers.
        cuda_cg_executed = CudaCodeGenExecuted()
        cuda_cg_created = CudaCodeGenCreated()

        def test(x, y, z):
            return torch.add(torch.add(x, y), z)

        rows, cols = 32, 32

        def fresh_inputs():
            # Three independent random operands of identical shape.
            return tuple(torch.rand(rows, cols, device="cuda") for _ in range(3))

        traced = torch.jit.trace(test, fresh_inputs())

        # Run on inputs different from the ones used for tracing.
        a, b, c = fresh_inputs()
        x = traced(a, b, c)

        expected = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
        np.testing.assert_allclose(expected, x.cpu().numpy())
        assert cuda_cg_executed.elapsed_value() >= 1
        assert cuda_cg_created.elapsed_value() >= 1
Пример #3
0
 def test_multi_rand(self):
     """Check that one ``rand_like`` value used twice is the same value.

     The scripted function computes ``(x + y) - (y - x)`` with
     ``y = torch.rand_like(x)``; the result is compared with ``2 * x``,
     which only holds if both uses of ``y`` see the same random tensor.
     The second call must also move the ``CudaCodeGenExecuted`` counter by
     exactly 1.

     Returns immediately, checking nothing, when CUDA is unavailable.
     """
     # The test allocates CUDA tensors, so it cannot run on a CPU-only host.
     if not torch.cuda.is_available():
         return

     def test(x):
         y = torch.rand_like(x)
         return (x + y) - (y - x)

     a = torch.rand(4, device="cuda")
     scripted = torch.jit.script(test)
     # First call happens before the counter exists, so it is not counted.
     # NOTE(review): presumably this is a profiling/warm-up run -- confirm.
     scripted(a)
     cx = CudaCodeGenExecuted()
     assert torch.allclose(scripted(a), 2 * a)
     assert cx.elapsed_value() == 1
Пример #4
0
 def test_unused(self):
     """Check CUDA codegen runs once when an input only feeds ``rand_like``.

     The scripted function returns ``x * x + torch.rand_like(y)``, so ``y``
     is passed to nothing but ``rand_like``.  After one uncounted call, a
     second call must move the ``CudaCodeGenExecuted`` counter by exactly 1.

     Returns immediately, checking nothing, when CUDA is unavailable.
     """
     # The test allocates CUDA tensors, so it cannot run on a CPU-only host.
     if not torch.cuda.is_available():
         return

     def test(x, y):
         return x * x + torch.rand_like(y)

     a = torch.rand(1, device="cuda")
     b = torch.rand(1, device="cuda")
     scripted = torch.jit.script(test)
     # First call happens before the counter exists, so it is not counted.
     # NOTE(review): presumably this is a profiling/warm-up run -- confirm.
     scripted(a, b)
     cx = CudaCodeGenExecuted()
     scripted(a, b)
     assert cx.elapsed_value() == 1
Пример #5
0
    def test_rand_diamond(self):
        """Check a random value feeding two branches cancels out when rejoined.

        The scripted function draws ``r = torch.rand_like(y)``, forms
        ``x + r`` and ``y - r`` and adds them, so the result must equal
        ``x + y``.  The call made after ``warmup_forward`` must also move the
        ``CudaCodeGenExecuted`` counter by exactly 1.

        The test is skipped when CUDA is unavailable.
        """
        # The test allocates CUDA tensors, so it cannot run on a CPU-only host.
        if not torch.cuda.is_available():
            self.skipTest("requires CUDA")

        def fn_test_diamond(x, y):
            r = torch.rand_like(y)
            a = x + r
            b = y - r
            return a + b

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        script_f = torch.jit.script(fn_test_diamond)
        warmup_forward(script_f, x, y)
        # Counter is created after the warm-up so only the next call counts.
        cx = CudaCodeGenExecuted()
        out = script_f(x, y)
        assert cx.elapsed_value() == 1
        self.assertEqual(out, x + y)
Пример #6
0
    def test_rand_cuda(self):
        """Check ``rand_like`` inside a script method on a zero CUDA tensor.

        ``create`` returns ``x * x + x + torch.rand_like(x)``.  With ``x`` all
        zeros the output is just the random term, so the test asserts that
        two calls give different outputs, that every element lies in
        ``[0, 1)``, that the second call moves the ``CudaCodeGenExecuted``
        counter by exactly 1, and that ``assertAllFused`` accepts the graph.

        The test is skipped when CUDA is unavailable.
        """
        # The test allocates CUDA tensors, so it cannot run on a CPU-only host.
        if not torch.cuda.is_available():
            self.skipTest("requires CUDA")

        class M(torch.jit.ScriptModule):
            __constants__ = ['d']

            def __init__(self):
                super(M, self).__init__()
                self.d = torch.device('cuda')

            @torch.jit.script_method
            def create(self, x):
                return x * x + x + torch.rand_like(x)

        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
        m = M()
        # First call happens before the counter exists, so it is not counted.
        out1 = m.create(x)
        cx = CudaCodeGenExecuted()
        out2 = m.create(x)
        assert cx.elapsed_value() == 1
        # Two independent draws should not produce identical tensors.
        self.assertNotEqual(out1, out2)
        self.assertTrue(torch.all(out1 >= 0))
        self.assertTrue(torch.all(out1 < 1))
        self.assertTrue(torch.all(out2 >= 0))
        self.assertTrue(torch.all(out2 < 1))
        self.assertAllFused(m.create.graph_for(x))
Пример #7
0
 def test_guard_fails(self):
     """Track the CUDA codegen counter across calls with changing input size.

     A scripted ``x * y * z`` is called four times with three random CUDA
     tensors each.  The ``CudaCodeGenExecuted`` counter is expected to read
     0, 1 and 2 after the three calls with size-4 inputs, and to stay at 2
     after a call with size-7 inputs.

     NOTE(review): presumably the first call is a profiling run and the
     size-7 call fails a shape guard and bypasses the generated kernel --
     confirm against the fuser implementation.

     Returns immediately, checking nothing, when CUDA is unavailable.
     """
     # The test allocates CUDA tensors, so it cannot run on a CPU-only host.
     if not torch.cuda.is_available():
         return

     @torch.jit.script
     def test(x, y, z):
         return x * y * z

     def cuda_inputs(size):
         # Three independent random 1-D tensors, created on CPU then moved.
         return [torch.rand(size).cuda() for _ in range(3)]

     cuda = CudaCodeGenExecuted()
     test(*cuda_inputs(4))
     assert cuda.elapsed_value() == 0
     test(*cuda_inputs(4))
     assert cuda.elapsed_value() == 1
     test(*cuda_inputs(4))
     assert cuda.elapsed_value() == 2
     # A different input size must not run the generated kernel again.
     test(*cuda_inputs(7))
     assert cuda.elapsed_value() == 2