def test_sliced_stride(self):
    """Fuse a three-way add where one input is strided along dim 0.

    ``x`` is the ``[::2]`` slice of a (16, 4, 2, 3) tensor, so it has shape
    (8, 4, 2, 3) but is non-contiguous along its first dimension. The scripted
    result is checked against eager-mode PyTorch.
    """
    @torch.jit.script
    def test(x, y, z):
        return x + y + z

    # NOTE(review): these helpers are defined elsewhere in the file; presumably
    # they count how many times each tensorexpr backend executed.
    llvm = LLVMCodeGenExecuted()
    interp = SimpleIREvalExecuted()
    x = torch.rand(16, 4, 2, 3)[::2]
    y = torch.rand(8, 4, 2, 3)
    z = torch.rand(8, 4, 2, 3)
    # The reference is computed eagerly rather than with the scripted function
    # under test, so a miscompilation shared by every scripted call is caught.
    ref = x + y + z
    # Call the scripted function twice, as before. The first call presumably
    # lets the executor profile/compile; the second should hit the fused
    # kernel. Both must agree with the eager reference.
    warm = test(x, y, z)
    res = test(x, y, z)
    np.testing.assert_allclose(ref.numpy(), warm.numpy())
    np.testing.assert_allclose(ref.numpy(), res.numpy())
    assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1
def test_transpose(self):
    """Fuse an add whose first operand is a transposed (non-contiguous) view.

    ``x`` has shape (4, 5, 2, 3); ``x.transpose(0, 1)`` has shape (5, 4, 2, 3),
    matching ``y`` and ``z``. The scripted result is checked against
    eager-mode PyTorch.
    """
    @torch.jit.script
    def test(x, y, z):
        return x.transpose(0, 1) + y + z

    # NOTE(review): these helpers are defined elsewhere in the file; presumably
    # they count how many times each tensorexpr backend executed.
    llvm = LLVMCodeGenExecuted()
    interp = SimpleIREvalExecuted()
    x = torch.rand(4, 5, 2, 3)
    y = torch.rand(5, 4, 2, 3)
    z = torch.rand(5, 4, 2, 3)
    # The reference is computed eagerly rather than with the scripted function
    # under test, so a miscompilation shared by every scripted call is caught.
    ref = x.transpose(0, 1) + y + z
    # Call the scripted function twice, as before. The first call presumably
    # lets the executor profile/compile; the second should hit the fused
    # kernel. Both must agree with the eager reference.
    warm = test(x, y, z)
    res = test(x, y, z)
    np.testing.assert_allclose(ref.numpy(), warm.numpy())
    np.testing.assert_allclose(ref.numpy(), res.numpy())
    assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1
def test_unsqueeze(self):
    """Trace ``unsqueeze(x, 0) + unsqueeze(y, 0)`` and compare with NumPy."""
    def easy(x, y):
        # Tracing records: unsqueeze(x), unsqueeze(y), then the add.
        return torch.unsqueeze(x, 0) + torch.unsqueeze(y, 0)

    example_inputs = (torch.ones(1024, 1024), torch.zeros(1024, 1024))
    traced = torch.jit.trace(easy, example_inputs)
    # NOTE(review): defined elsewhere in the file; presumably backend run counters.
    llvm = LLVMCodeGenExecuted()
    interp = SimpleIREvalExecuted()
    inp = torch.rand(1024, 1024)
    out = traced(inp, inp)
    # NumPy equivalent of unsqueezing along a new leading axis.
    expanded = np.expand_dims(inp.numpy(), 0)
    np.testing.assert_allclose(expanded + expanded, out.numpy())
    assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1
def test_slice(self):
    """Trace a strided slice ``[0:512:2]`` of both inputs followed by an add.

    The inputs are distinct random tensors. With all-ones data passed as both
    operands, every output element equals 2. A wrong slice stride or offset,
    or swapped operands, would then go unnoticed.
    """
    def easy(x, y):
        a = x[0:512:2]
        b = y[0:512:2]
        return a + b

    traced = torch.jit.trace(
        easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))
    # NOTE(review): defined elsewhere in the file; presumably backend run counters.
    llvm = LLVMCodeGenExecuted()
    interp = SimpleIREvalExecuted()
    a = torch.rand(1024, 1024)
    b = torch.rand(1024, 1024)
    x = traced(a, b)
    # Eager-mode reference for the same slicing and add.
    ref = a[0:512:2] + b[0:512:2]
    np.testing.assert_allclose(ref.numpy(), x.numpy())
    assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1
def test_three_arg(self):
    """Trace a chained ``torch.add`` of three 1-D tensors and compare with NumPy."""
    # NOTE(review): defined elsewhere in the file; presumably backend run counters.
    llvm_counter = LLVMCodeGenExecuted()
    interp_counter = SimpleIREvalExecuted()

    def easy(x, y, z):
        # Tracing records: add(x, y), then add(<that>, z).
        return torch.add(torch.add(x, y), z)

    example_inputs = tuple(torch.rand(1024) for _ in range(3))
    traced = torch.jit.trace(easy, example_inputs)
    a, b, c = (torch.rand(1024) for _ in range(3))
    out = traced(a, b, c)
    expected = a.numpy() + b.numpy() + c.numpy()
    np.testing.assert_allclose(expected, out.numpy())
    # Only one traced call here, but the check is ">= 1" rather than "== 1".
    assert (llvm_counter.elapsed_value() >= 1
            or interp_counter.elapsed_value() >= 1)
def test_scalar(self):
    """Check ``torch.add`` with scalar ``alpha`` for both float and int scalars.

    Each scripted variant computes ``x + alpha*y + beta*z`` and is compared with
    the NumPy equivalent after a warm-up call.
    """
    @torch.jit.script
    def test_float(x, y, z, a, b):
        # type: (Tensor, Tensor, Tensor, float, float) -> Tensor
        return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

    @torch.jit.script
    def test_int(x, y, z, a, b):
        # type: (Tensor, Tensor, Tensor, int, int) -> Tensor
        return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

    alpha, beta = 1, 2
    for fn in (test_float, test_int):
        # NOTE(review): defined elsewhere in the file; presumably backend run
        # counters, recreated per variant so each is checked independently.
        llvm = LLVMCodeGenExecuted()
        interp = SimpleIREvalExecuted()
        x, y, z = (torch.rand(4) for _ in range(3))
        fn(x, y, z, alpha, beta)  # warm-up call; result intentionally unused
        result = fn(x, y, z, alpha, beta)
        expected = x.numpy() + y.numpy() * alpha + z.numpy() * beta
        np.testing.assert_allclose(result.numpy(), expected)
        assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1