def test_sliced_stride(self):
    """Add a step-2 sliced tensor to two freshly allocated tensors via TorchScript.

    The scripted function is invoked twice with the same inputs and the two
    results are compared against each other. Afterwards, at least one of the
    ``LLVMCodeGenExecuted`` / ``SimpleIREvalExecuted`` helpers must report an
    ``elapsed_value()`` of exactly 1.
    """
    @torch.jit.script
    def add_three(x, y, z):
        return x + y + z

    llvm_executed = LLVMCodeGenExecuted()
    ir_eval_executed = SimpleIREvalExecuted()

    # Taking every other row of a (16, 4, 2, 3) tensor gives an (8, 4, 2, 3)
    # view whose leading stride is doubled.
    strided = torch.rand(16, 4, 2, 3)[::2]
    dense_a = torch.rand(8, 4, 2, 3)
    dense_b = torch.rand(8, 4, 2, 3)

    first = add_three(strided, dense_a, dense_b)
    second = add_three(strided, dense_a, dense_b)
    np.testing.assert_allclose(first.numpy(), second.numpy())

    ran_once = (llvm_executed.elapsed_value() == 1
                or ir_eval_executed.elapsed_value() == 1)
    assert ran_once
 def test_transpose(self):
     """Script ``x.transpose(0, 1) + y + z`` and run it twice on the same inputs.

     ``x`` is (4, 5, 2, 3) so that its transpose matches the (5, 4, 2, 3)
     shape of ``y`` and ``z``. The two results are compared against each
     other, and at least one of the ``LLVMCodeGenExecuted`` /
     ``SimpleIREvalExecuted`` helpers must report an ``elapsed_value()`` of
     exactly 1.
     """
     @torch.jit.script
     def transpose_add(x, y, z):
         return x.transpose(0, 1) + y + z

     llvm_executed = LLVMCodeGenExecuted()
     ir_eval_executed = SimpleIREvalExecuted()

     to_transpose = torch.rand(4, 5, 2, 3)
     addend_a = torch.rand(5, 4, 2, 3)
     addend_b = torch.rand(5, 4, 2, 3)

     first = transpose_add(to_transpose, addend_a, addend_b)
     second = transpose_add(to_transpose, addend_a, addend_b)
     np.testing.assert_allclose(first.numpy(), second.numpy())

     ran_once = (llvm_executed.elapsed_value() == 1
                 or ir_eval_executed.elapsed_value() == 1)
     assert ran_once
    def test_unsqueeze(self):
        """Trace a function that unsqueezes both inputs at dim 0 and adds them.

        The traced result is checked against a NumPy reference built with
        ``np.expand_dims``. At least one of the ``LLVMCodeGenExecuted`` /
        ``SimpleIREvalExecuted`` helpers (created after tracing) must report
        an ``elapsed_value()`` of exactly 1.
        """
        def unsqueeze_add(x, y):
            lhs = torch.unsqueeze(x, 0)
            rhs = torch.unsqueeze(y, 0)
            return lhs + rhs

        example_inputs = (torch.ones(1024, 1024), torch.zeros(1024, 1024))
        traced = torch.jit.trace(unsqueeze_add, example_inputs)

        llvm_executed = LLVMCodeGenExecuted()
        ir_eval_executed = SimpleIREvalExecuted()

        data = torch.rand(1024, 1024)
        actual = traced(data, data)

        # Reference: add a leading axis, then add the array to itself.
        expanded = np.expand_dims(data, 0)
        expected = expanded + expanded
        np.testing.assert_allclose(expected, actual.numpy())

        ran_once = (llvm_executed.elapsed_value() == 1
                    or ir_eval_executed.elapsed_value() == 1)
        assert ran_once
    def test_slice(self):
        """Trace a function that adds step-2 slices (rows 0..511) of its inputs.

        The traced output is compared with the same slice-and-add computed
        directly on the input tensor. At least one of the
        ``LLVMCodeGenExecuted`` / ``SimpleIREvalExecuted`` helpers (created
        after tracing) must report an ``elapsed_value()`` of exactly 1.
        """
        def slice_add(x, y):
            # Left operand is evaluated first, so x is sliced before y.
            return x[0:512:2] + y[0:512:2]

        example_inputs = (torch.ones(1024, 1024), torch.zeros(1024, 1024))
        traced = torch.jit.trace(slice_add, example_inputs)

        llvm_executed = LLVMCodeGenExecuted()
        ir_eval_executed = SimpleIREvalExecuted()

        ones = torch.ones(1024, 1024)
        actual = traced(ones, ones)

        sliced = ones[0:512:2]
        expected = sliced + sliced
        np.testing.assert_allclose(expected.numpy(), actual.numpy())

        ran_once = (llvm_executed.elapsed_value() == 1
                    or ir_eval_executed.elapsed_value() == 1)
        assert ran_once
    def test_three_arg(self):
        """Trace a chained three-tensor ``torch.add`` and check it against NumPy.

        The ``LLVMCodeGenExecuted`` / ``SimpleIREvalExecuted`` helpers are
        created before tracing, and the final check only requires that at
        least one of them reports an ``elapsed_value()`` of 1 or more.
        """
        llvm_probe = LLVMCodeGenExecuted()
        interp_probe = SimpleIREvalExecuted()

        def add_three(x, y, z):
            partial = torch.add(x, y)
            return torch.add(partial, z)

        example_inputs = (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        traced = torch.jit.trace(add_three, example_inputs)

        first = torch.rand(1024)
        second = torch.rand(1024)
        third = torch.rand(1024)
        actual = traced(first, second, third)

        expected = first.numpy() + second.numpy() + third.numpy()
        np.testing.assert_allclose(expected, actual.numpy())

        ran_at_least_once = (llvm_probe.elapsed_value() >= 1
                             or interp_probe.elapsed_value() >= 1)
        assert ran_at_least_once
    def test_scalar(self):
        """Check ``torch.add`` with scalar ``alpha`` arguments typed as float and int.

        Two scripted variants compute ``x + a * y + b * z`` and differ only in
        the declared type of ``a`` and ``b``. Each variant is called twice
        with ``a=1, b=2``; the first result is discarded and the second is
        compared with a NumPy reference. For each variant, at least one of the
        ``LLVMCodeGenExecuted`` / ``SimpleIREvalExecuted`` helpers must report
        an ``elapsed_value()`` of exactly 1.
        """
        @torch.jit.script
        def add_scaled_float(x, y, z, a, b):
            # type: (Tensor, Tensor, Tensor, float, float) -> Tensor
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        @torch.jit.script
        def add_scaled_int(x, y, z, a, b):
            # type: (Tensor, Tensor, Tensor, int, int) -> Tensor
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        for scripted in (add_scaled_float, add_scaled_int):
            llvm_executed = LLVMCodeGenExecuted()
            ir_eval_executed = SimpleIREvalExecuted()

            x, y, z = [torch.rand(4) for _ in range(3)]
            scale_y, scale_z = 1, 2

            # First invocation's result is intentionally unused.
            scripted(x, y, z, scale_y, scale_z)
            result = scripted(x, y, z, scale_y, scale_z)

            expected = x.numpy() + y.numpy() * scale_y + z.numpy() * scale_z
            np.testing.assert_allclose(result.numpy(), expected)

            ran_once = (llvm_executed.elapsed_value() == 1
                        or ir_eval_executed.elapsed_value() == 1)
            assert ran_once